mirror of https://github.com/apache/superset.git
chore(pre-commit): Add pyupgrade and pycln hooks (#24197)
parent 7d7ce63970
commit a4d5d7c6b9
@@ -15,14 +15,28 @@
 # limitations under the License.
 #
 repos:
+  - repo: https://github.com/MarcoGorelli/auto-walrus
+    rev: v0.2.2
+    hooks:
+      - id: auto-walrus
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.4.0
+    hooks:
+      - id: pyupgrade
+        args:
+          - --py39-plus
+  - repo: https://github.com/hadialqattan/pycln
+    rev: v2.1.2
+    hooks:
+      - id: pycln
+        args:
+          - --disable-all-dunder-policy
+          - --exclude=superset/config.py
+          - --extend-exclude=tests/integration_tests/superset_test_config.*.py
   - repo: https://github.com/PyCQA/isort
     rev: 5.12.0
     hooks:
       - id: isort
-  - repo: https://github.com/MarcoGorelli/auto-walrus
-    rev: v0.2.2
-    hooks:
-      - id: auto-walrus
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.3.0
     hooks:
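The two hooks added above drive almost every hunk that follows: `pyupgrade --py39-plus` rewrites pre-PEP 585 typing constructs (Dict, List, Set, Tuple, Type) to their builtin generic equivalents, and pycln then strips whatever imports that rewrite leaves unused. A minimal before/after sketch of the combined effect (hypothetical module, not part of this commit):

    # before
    from typing import Dict, List, Optional

    def tally(rows: List[Dict[str, int]], label: Optional[str] = None) -> Dict[str, int]:
        return rows[0] if rows else {}

    # after pyupgrade --py39-plus, then pycln (Dict/List imports are now unused and dropped)
    from typing import Optional

    def tally(rows: list[dict[str, int]], label: Optional[str] = None) -> dict[str, int]:
        return rows[0] if rows else {}

Note that `Optional` survives: rewriting it to `X | None` needs `--py310-plus` or a file that already has `from __future__ import annotations`, which is why both spellings appear in the hunks below.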
@@ -17,8 +17,9 @@ import csv as lib_csv
 import os
 import re
 import sys
+from collections.abc import Iterator
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Optional, Union

 import click
 from click.core import Context

@@ -67,15 +68,15 @@ class GitChangeLog:
     def __init__(
         self,
         version: str,
-        logs: List[GitLog],
+        logs: list[GitLog],
         access_token: Optional[str] = None,
         risk: Optional[bool] = False,
     ) -> None:
         self._version = version
         self._logs = logs
-        self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {}
-        self._github_login_cache: Dict[str, Optional[str]] = {}
-        self._github_prs: Dict[int, Any] = {}
+        self._pr_logs_with_details: dict[int, dict[str, Any]] = {}
+        self._github_login_cache: dict[str, Optional[str]] = {}
+        self._github_prs: dict[int, Any] = {}
         self._wait = 10
         github_token = access_token or os.environ.get("GITHUB_TOKEN")
         self._github = Github(github_token)

@@ -126,7 +127,7 @@ class GitChangeLog:
             "superset/migrations/versions/" in file.filename for file in commit.files
         )

-    def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:
+    def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]:
         pr_number = git_log.pr_number
         if pr_number:
             detail = self._pr_logs_with_details.get(pr_number)

@@ -156,7 +157,7 @@ class GitChangeLog:

         return detail

-    def _is_risk_pull_request(self, labels: List[Any]) -> bool:
+    def _is_risk_pull_request(self, labels: list[Any]) -> bool:
         for label in labels:
             risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
             if risk_label is not None:

@@ -174,8 +175,8 @@ class GitChangeLog:

     def _parse_change_log(
         self,
-        changelog: Dict[str, str],
-        pr_info: Dict[str, str],
+        changelog: dict[str, str],
+        pr_info: dict[str, str],
         github_login: str,
     ) -> None:
         formatted_pr = (

@@ -227,7 +228,7 @@ class GitChangeLog:
             result += f"**{key}** {changelog[key]}\n"
         return result

-    def __iter__(self) -> Iterator[Dict[str, Any]]:
+    def __iter__(self) -> Iterator[dict[str, Any]]:
         for log in self._logs:
             yield {
                 "pr_number": log.pr_number,

@@ -250,20 +251,20 @@ class GitLogs:

     def __init__(self, git_ref: str) -> None:
         self._git_ref = git_ref
-        self._logs: List[GitLog] = []
+        self._logs: list[GitLog] = []

     @property
     def git_ref(self) -> str:
         return self._git_ref

     @property
-    def logs(self) -> List[GitLog]:
+    def logs(self) -> list[GitLog]:
         return self._logs

     def fetch(self) -> None:
         self._logs = list(map(self._parse_log, self._git_logs()))[::-1]

-    def diff(self, git_logs: "GitLogs") -> List[GitLog]:
+    def diff(self, git_logs: "GitLogs") -> list[GitLog]:
         return [log for log in git_logs.logs if log not in self._logs]

     def __repr__(self) -> str:

@@ -284,7 +285,7 @@ class GitLogs:
             print(f"Could not checkout {git_ref}")
             sys.exit(1)

-    def _git_logs(self) -> List[str]:
+    def _git_logs(self) -> list[str]:
         # let's get current git ref so we can revert it back
         current_git_ref = self._git_get_current_head()
         self._git_checkout(self._git_ref)
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, List
+from typing import Any

 from click.core import Context

@@ -34,7 +34,7 @@ PROJECT_MODULE = "superset"
 PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"


-def string_comma_to_list(message: str) -> List[str]:
+def string_comma_to_list(message: str) -> list[str]:
     if not message:
         return []
     return [element.strip() for element in message.split(",")]

@@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str:
     return template.render(kwargs)


-class BaseParameters(object):
+class BaseParameters:
     def __init__(
         self,
         version: str,

@@ -60,7 +60,7 @@ class BaseParameters(object):
     ) -> None:
         self.version = version
         self.version_rc = version_rc
-        self.template_arguments: Dict[str, Any] = {}
+        self.template_arguments: dict[str, Any] = {}

     def __repr__(self) -> str:
         return f"Apache Credentials: {self.version}/{self.version_rc}"
@@ -22,7 +22,6 @@
 #
 import logging
 import os
-from datetime import timedelta
 from typing import Optional

 from cachelib.file import FileSystemCache

@@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
         error_msg = "The environment variable {} was missing, abort...".format(
             var_name
         )
-        raise EnvironmentError(error_msg)
+        raise OSError(error_msg)


 DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")

@@ -53,7 +52,7 @@ DATABASE_PORT = get_env_variable("DATABASE_PORT")
 DATABASE_DB = get_env_variable("DATABASE_DB")

 # The SQLAlchemy connection string.
-SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
+SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format(
     DATABASE_DIALECT,
     DATABASE_USER,
     DATABASE_PASSWORD,

@@ -80,7 +79,7 @@ CACHE_CONFIG = {
 DATA_CACHE_CONFIG = CACHE_CONFIG


-class CeleryConfig(object):
+class CeleryConfig:
     broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
     imports = ("superset.sql_lab",)
     result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
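Two rewrites in this config file are worth a note: `EnvironmentError` has been a plain alias of `OSError` since Python 3.3, and pyupgrade normalizes percent-formatting to `str.format` (it is conservative about going all the way to an f-string when the call spans multiple lines). A standalone sketch of the equivalences, not taken from this file:

    scheme, host = "postgresql", "db"
    assert "%s://%s" % (scheme, host) \
        == "{}://{}".format(scheme, host) \
        == f"{scheme}://{host}"

    assert EnvironmentError is OSError  # alias since Python 3.3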
@@ -23,7 +23,7 @@ from graphlib import TopologicalSorter
 from inspect import getsource
 from pathlib import Path
 from types import ModuleType
-from typing import Any, Dict, List, Set, Type
+from typing import Any

 import click
 from flask import current_app

@@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
         module = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(module)  # type: ignore
         return module
-    raise Exception(
-        "No module spec found in location: `{path}`".format(path=str(filepath))
-    )
+    raise Exception(f"No module spec found in location: `{str(filepath)}`")


-def extract_modified_tables(module: ModuleType) -> Set[str]:
+def extract_modified_tables(module: ModuleType) -> set[str]:
    """
    Extract the tables being modified by a migration script.

@@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
     actually traversing the AST.
     """

-    tables: Set[str] = set()
+    tables: set[str] = set()
     for function in {"upgrade", "downgrade"}:
         source = getsource(getattr(module, function))
         tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))

@@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
     return tables


-def find_models(module: ModuleType) -> List[Type[Model]]:
+def find_models(module: ModuleType) -> list[type[Model]]:
     """
     Find all models in a migration script.
     """
-    models: List[Type[Model]] = []
+    models: list[type[Model]] = []
     tables = extract_modified_tables(module)

     # add models defined explicitly in the migration script

@@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
     sorter: TopologicalSorter[Any] = TopologicalSorter()
     for model in models:
         inspector = inspect(model)
-        dependent_tables: List[str] = []
+        dependent_tables: list[str] = []
         for column in inspector.columns.values():
             for foreign_key in column.foreign_keys:
                 if foreign_key.column.table.name != model.__tablename__:

@@ -174,7 +172,7 @@ def main(

     print("\nIdentifying models used in the migration:")
     models = find_models(module)
-    model_rows: Dict[Type[Model], int] = {}
+    model_rows: dict[type[Model], int] = {}
     for model in models:
         rows = session.query(model).count()
         print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")

@@ -182,7 +180,7 @@ def main(
     session.close()

     print("Benchmarking migration")
-    results: Dict[str, float] = {}
+    results: dict[str, float] = {}
     start = time.time()
     upgrade(revision=revision)
     duration = time.time() - start

@@ -190,14 +188,14 @@ def main(
     print(f"Migration on current DB took: {duration:.2f} seconds")

     min_entities = 10
-    new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
+    new_models: dict[type[Model], list[Model]] = defaultdict(list)
     while min_entities <= limit:
         downgrade(revision=down_revision)
         print(f"Running with at least {min_entities} entities of each model")
         for model in models:
             missing = min_entities - model_rows[model]
             if missing > 0:
                 entities: List[Model] = []
-                entities: List[Model] = []
+                entities: list[Model] = []
                 print(f"- Adding {missing} entities to the {model.__name__} model")
                 bar = ChargingBar("Processing", max=missing)
                 try:
@@ -33,13 +33,13 @@ Example:
   ./cancel_github_workflows.py 1024 --include-last
 """
 import os
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, Literal, Optional, Union

 import click
 import requests
 from click.exceptions import ClickException
 from dateutil import parser
-from typing_extensions import Literal

 github_token = os.environ.get("GITHUB_TOKEN")
 github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")

@@ -47,7 +47,7 @@ github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")

 def request(
     method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     resp = requests.request(
         method,
         f"https://api.github.com/{endpoint.lstrip('/')}",

@@ -61,8 +61,8 @@ def request(

 def list_runs(
     repo: str,
-    params: Optional[Dict[str, str]] = None,
-) -> Iterator[Dict[str, Any]]:
+    params: Optional[dict[str, str]] = None,
+) -> Iterator[dict[str, Any]]:
     """List all github workflow runs.
     Returns:
         An iterator that will iterate through all pages of matching runs."""

@@ -77,16 +77,15 @@ def list_runs(
             params={**params, "per_page": 100, "page": page},
         )
         total_count = result["total_count"]
-        for item in result["workflow_runs"]:
-            yield item
+        yield from result["workflow_runs"]
         page += 1


-def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
+def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]:
     return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")


-def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
+def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]:
     return request("GET", f"/repos/{repo}/pulls/{pull_number}")


@@ -96,7 +95,7 @@ def get_runs(
     user: Optional[str] = None,
     statuses: Iterable[str] = ("queued", "in_progress"),
     events: Iterable[str] = ("pull_request", "push"),
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """Get workflow runs associated with the given branch"""
     return [
         item

@@ -108,7 +107,7 @@ def get_runs(
     ]


-def print_commit(commit: Dict[str, Any], branch: str) -> None:
+def print_commit(commit: dict[str, Any], branch: str) -> None:
     """Print out commit message for verification"""
     indented_message = " \n".join(commit["message"].split("\n"))
     date_str = (

@@ -155,7 +154,7 @@ Date: {date_str}
 def cancel_github_workflows(
     branch_or_pull: Optional[str],
     repo: str,
-    event: List[str],
+    event: list[str],
     include_last: bool,
     include_running: bool,
 ) -> None:
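The `yield from` change above is purely mechanical: looping over an iterable only to re-yield each element is exactly what `yield from` expresses. A small self-contained sketch of the equivalence (hypothetical pagination, not the GitHub API):

    from collections.abc import Iterator

    def pages() -> Iterator[list[int]]:
        yield [1, 2]
        yield [3]

    def flatten_loop() -> Iterator[int]:
        for page in pages():
            for item in page:  # the pattern pyupgrade rewrites
                yield item

    def flatten_delegated() -> Iterator[int]:
        for page in pages():
            yield from page  # equivalent for plain iteration

    assert list(flatten_loop()) == list(flatten_delegated()) == [1, 2, 3]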
@@ -24,7 +24,7 @@ def cleanup_permissions() -> None:
     pvms = security_manager.get_session.query(
         security_manager.permissionview_model
     ).all()
-    print("# of permission view menus is: {}".format(len(pvms)))
+    print(f"# of permission view menus is: {len(pvms)}")
     pvms_dict = defaultdict(list)
     for pvm in pvms:
         pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)

@@ -43,7 +43,7 @@ def cleanup_permissions() -> None:
     pvms = security_manager.get_session.query(
         security_manager.permissionview_model
     ).all()
-    print("Stage 1: # of permission view menus is: {}".format(len(pvms)))
+    print(f"Stage 1: # of permission view menus is: {len(pvms)}")

     # 2. Clean up None permissions or view menus
     pvms = security_manager.get_session.query(

@@ -57,7 +57,7 @@ def cleanup_permissions() -> None:
     pvms = security_manager.get_session.query(
         security_manager.permissionview_model
     ).all()
-    print("Stage 2: # of permission view menus is: {}".format(len(pvms)))
+    print(f"Stage 2: # of permission view menus is: {len(pvms)}")

     # 3. Delete empty permission view menus from roles
     roles = security_manager.get_session.query(security_manager.role_model).all()
setup.py (6 changed lines)

@@ -14,21 +14,19 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import io
 import json
 import os
 import subprocess
 import sys

 from setuptools import find_packages, setup

 BASE_DIR = os.path.abspath(os.path.dirname(__file__))
 PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json")

-with open(PACKAGE_JSON, "r") as package_file:
+with open(PACKAGE_JSON) as package_file:
     version_string = json.load(package_file)["version"]

-with io.open("README.md", "r", encoding="utf-8") as f:
+with open("README.md", encoding="utf-8") as f:
     long_description = f.read()

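The setup.py simplifications rest on two long-standing facts: on Python 3, `io.open` is the very same function as the builtin `open`, and `"r"` is already the default mode, so dropping both changes nothing at runtime. A quick sketch:

    import io

    assert io.open is open  # alias on Python 3

    # open(path) == open(path, "r"); only the encoding argument is
    # worth keeping explicit for cross-platform text reads
    with open("README.md", encoding="utf-8") as f:
        long_description = f.read()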
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import ipaddress
-from typing import Any, List
+from typing import Any

 from sqlalchemy import Column

@@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:

 # Make this return a single clause
 def cidr_translate_filter_func(
-    col: Column, operator: FilterOperator, values: List[Any]
+    col: Column, operator: FilterOperator, values: list[Any]
 ) -> Any:
     """
     Convert a passed in column, FilterOperator and
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import itertools
-from typing import Any, Dict, List
+from typing import Any

 from sqlalchemy import Column

@@ -26,7 +26,7 @@ from superset.advanced_data_type.types import (
 )
 from superset.utils.core import FilterOperator, FilterStringOperators

-port_conversion_dict: Dict[str, List[int]] = {
+port_conversion_dict: dict[str, list[int]] = {
     "http": [80],
     "ssh": [22],
     "https": [443],

@@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:


 def port_translate_filter_func(
-    col: Column, operator: FilterOperator, values: List[Any]
+    col: Column, operator: FilterOperator, values: list[Any]
 ) -> Any:
     """
     Convert a passed in column, FilterOperator
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, TypedDict, Union
+from typing import Any, Callable, Optional, TypedDict, Union

 from sqlalchemy import Column
 from sqlalchemy.sql.expression import BinaryExpression

@@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict):
     """

     advanced_data_type: str
-    values: List[
+    values: list[
         Union[FilterValues, None]
     ]  # unparsed value (usually text when passed from text box)

@@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False):
     """

     error_message: Optional[str]
-    values: List[Any]  # parsed value (can be any value)
+    values: list[Any]  # parsed value (can be any value)
     display_value: str  # The string representation of the parsed values
-    valid_filter_operators: List[FilterStringOperators]
+    valid_filter_operators: list[FilterStringOperators]


 @dataclass

@@ -54,6 +54,6 @@ class AdvancedDataType:

     verbose_name: str
     description: str
-    valid_data_types: List[str]
+    valid_data_types: list[str]
     translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse]
     translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression]
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict
+from typing import Any

 from flask import request, Response
 from flask_appbuilder.api import expose, permission_name, protect, rison, safe

@@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):

     @staticmethod
     def _apply_layered_relation_to_rison(  # pylint: disable=invalid-name
-        layer_id: int, rison_parameters: Dict[str, Any]
+        layer_id: int, rison_parameters: dict[str, Any]
     ) -> None:
         if "filters" not in rison_parameters:
             rison_parameters["filters"] = []
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from superset.annotation_layers.annotations.commands.exceptions import (
     AnnotationBulkDeleteFailedError,

@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)


 class BulkDeleteAnnotationCommand(BaseCommand):
-    def __init__(self, model_ids: List[int]):
+    def __init__(self, model_ids: list[int]):
         self._model_ids = model_ids
-        self._models: Optional[List[Annotation]] = None
+        self._models: Optional[list[Annotation]] = None

     def run(self) -> None:
         self.validate()
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)


 class CreateAnnotationCommand(BaseCommand):
-    def __init__(self, data: Dict[str, Any]):
+    def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()

     def run(self) -> Model:

@@ -50,7 +50,7 @@ class CreateAnnotationCommand(BaseCommand):
         return annotation

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         layer_id: Optional[int] = self._properties.get("layer")
         start_dttm: Optional[datetime] = self._properties.get("start_dttm")
         end_dttm: Optional[datetime] = self._properties.get("end_dttm")
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)


 class UpdateAnnotationCommand(BaseCommand):
-    def __init__(self, model_id: int, data: Dict[str, Any]):
+    def __init__(self, model_id: int, data: dict[str, Any]):
         self._model_id = model_id
         self._properties = data.copy()
         self._model: Optional[Annotation] = None

@@ -54,7 +54,7 @@ class UpdateAnnotationCommand(BaseCommand):
         return annotation

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         layer_id: Optional[int] = self._properties.get("layer")
         short_descr: str = self._properties.get("short_descr", "")

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from sqlalchemy.exc import SQLAlchemyError

@@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO):
     model_cls = Annotation

     @staticmethod
-    def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None:
+    def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
         item_ids = [model.id for model in models] if models else []
         try:
             db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete(
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from superset.annotation_layers.commands.exceptions import (
     AnnotationLayerBulkDeleteFailedError,

@@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)


 class BulkDeleteAnnotationLayerCommand(BaseCommand):
-    def __init__(self, model_ids: List[int]):
+    def __init__(self, model_ids: list[int]):
         self._model_ids = model_ids
-        self._models: Optional[List[AnnotationLayer]] = None
+        self._models: Optional[list[AnnotationLayer]] = None

     def run(self) -> None:
         self.validate()
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List
+from typing import Any

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)


 class CreateAnnotationLayerCommand(BaseCommand):
-    def __init__(self, data: Dict[str, Any]):
+    def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()

     def run(self) -> Model:

@@ -46,7 +46,7 @@ class CreateAnnotationLayerCommand(BaseCommand):
         return annotation_layer

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []

         name = self._properties.get("name", "")

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)


 class UpdateAnnotationLayerCommand(BaseCommand):
-    def __init__(self, model_id: int, data: Dict[str, Any]):
+    def __init__(self, model_id: int, data: dict[str, Any]):
         self._model_id = model_id
         self._properties = data.copy()
         self._model: Optional[AnnotationLayer] = None

@@ -50,7 +50,7 @@ class UpdateAnnotationLayerCommand(BaseCommand):
         return annotation_layer

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         name = self._properties.get("name", "")
         self._model = AnnotationLayerDAO.find_by_id(self._model_id)

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional, Union
+from typing import Optional, Union

 from sqlalchemy.exc import SQLAlchemyError

@@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO):

     @staticmethod
     def bulk_delete(
-        models: Optional[List[AnnotationLayer]], commit: bool = True
+        models: Optional[list[AnnotationLayer]], commit: bool = True
     ) -> None:
         item_ids = [model.id for model in models] if models else []
         try:

@@ -46,7 +46,7 @@ class AnnotationLayerDAO(BaseDAO):
             raise DAODeleteFailedError() from ex

     @staticmethod
-    def has_annotations(model_id: Union[int, List[int]]) -> bool:
+    def has_annotations(model_id: Union[int, list[int]]) -> bool:
         if isinstance(model_id, list):
             return (
                 db.session.query(AnnotationLayer)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from flask_babel import lazy_gettext as _

@@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)


 class BulkDeleteChartCommand(BaseCommand):
-    def __init__(self, model_ids: List[int]):
+    def __init__(self, model_ids: list[int]):
         self._model_ids = model_ids
-        self._models: Optional[List[Slice]] = None
+        self._models: Optional[list[Slice]] = None

     def run(self) -> None:
         self.validate()
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask import g
 from flask_appbuilder.models.sqla import Model

@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)


 class CreateChartCommand(CreateMixin, BaseCommand):
-    def __init__(self, data: Dict[str, Any]):
+    def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()

     def run(self) -> Model:

@@ -56,7 +56,7 @@ class CreateChartCommand(CreateMixin, BaseCommand):
         datasource_type = self._properties["datasource_type"]
         datasource_id = self._properties["datasource_id"]
         dashboard_ids = self._properties.get("dashboards", [])
-        owner_ids: Optional[List[int]] = self._properties.get("owners")
+        owner_ids: Optional[list[int]] = self._properties.get("owners")

         # Validate/Populate datasource
         try:
@@ -18,7 +18,7 @@

 import json
 import logging
-from typing import Iterator, Tuple
+from collections.abc import Iterator

 import yaml

@@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand):
     not_found = ChartNotFoundError

     @staticmethod
-    def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]:
+    def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]:
         file_name = get_filename(model.slice_name, model.id)
         file_path = f"charts/{file_name}.yaml"

@@ -16,7 +16,7 @@
 # under the License.

 import logging
-from typing import Any, Dict
+from typing import Any

 from marshmallow.exceptions import ValidationError

@@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand):
     until it finds one that matches.
     """

-    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+    def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
         self.contents = contents
         self.args = args
         self.kwargs = kwargs
@@ -15,8 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

 import json
-from typing import Any, Dict, Set
+from typing import Any

 from marshmallow import Schema
 from sqlalchemy.orm import Session

@@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand):
     dao = ChartDAO
     model_name = "chart"
     prefix = "charts/"
-    schemas: Dict[str, Schema] = {
+    schemas: dict[str, Schema] = {
         "charts/": ImportV1ChartSchema(),
         "datasets/": ImportV1DatasetSchema(),
         "databases/": ImportV1DatabaseSchema(),

@@ -49,29 +48,29 @@

     @staticmethod
     def _import(
-        session: Session, configs: Dict[str, Any], overwrite: bool = False
+        session: Session, configs: dict[str, Any], overwrite: bool = False
     ) -> None:
         # discover datasets associated with charts
-        dataset_uuids: Set[str] = set()
+        dataset_uuids: set[str] = set()
         for file_name, config in configs.items():
             if file_name.startswith("charts/"):
                 dataset_uuids.add(config["dataset_uuid"])

         # discover databases associated with datasets
-        database_uuids: Set[str] = set()
+        database_uuids: set[str] = set()
         for file_name, config in configs.items():
             if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
                 database_uuids.add(config["database_uuid"])

         # import related databases
-        database_ids: Dict[str, int] = {}
+        database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/") and config["uuid"] in database_uuids:
                 database = import_database(session, config, overwrite=False)
                 database_ids[str(database.uuid)] = database.id

         # import datasets with the correct parent ref
-        datasets: Dict[str, SqlaTable] = {}
+        datasets: dict[str, SqlaTable] = {}
         for file_name, config in configs.items():
             if (
                 file_name.startswith("datasets/")
@@ -16,7 +16,7 @@
 # under the License.

 import json
-from typing import Any, Dict
+from typing import Any

 from flask import g
 from sqlalchemy.orm import Session

@@ -28,7 +28,7 @@ from superset.models.slice import Slice

 def import_chart(
     session: Session,
-    config: Dict[str, Any],
+    config: dict[str, Any],
     overwrite: bool = False,
     ignore_permissions: bool = False,
 ) -> Slice:
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask import g
 from flask_appbuilder.models.sqla import Model

@@ -42,14 +42,14 @@ from superset.models.slice import Slice
 logger = logging.getLogger(__name__)


-def is_query_context_update(properties: Dict[str, Any]) -> bool:
+def is_query_context_update(properties: dict[str, Any]) -> bool:
     return set(properties) == {"query_context", "query_context_generation"} and bool(
         properties.get("query_context_generation")
     )


 class UpdateChartCommand(UpdateMixin, BaseCommand):
-    def __init__(self, model_id: int, data: Dict[str, Any]):
+    def __init__(self, model_id: int, data: dict[str, Any]):
         self._model_id = model_id
         self._properties = data.copy()
         self._model: Optional[Slice] = None

@@ -67,9 +67,9 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
         return chart

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         dashboard_ids = self._properties.get("dashboards")
-        owner_ids: Optional[List[int]] = self._properties.get("owners")
+        owner_ids: Optional[list[int]] = self._properties.get("owners")

         # Validate if datasource_id is provided datasource_type is required
         datasource_id = self._properties.get("datasource_id")
@@ -17,7 +17,7 @@
 # pylint: disable=arguments-renamed
 import logging
 from datetime import datetime
-from typing import List, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING

 from sqlalchemy.exc import SQLAlchemyError

@@ -39,7 +39,7 @@ class ChartDAO(BaseDAO):
     base_filter = ChartFilter

     @staticmethod
-    def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
+    def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None:
         item_ids = [model.id for model in models] if models else []
         # bulk delete, first delete related data
         if models:

@@ -71,7 +71,7 @@ class ChartDAO(BaseDAO):
             db.session.commit()

     @staticmethod
-    def favorited_ids(charts: List[Slice]) -> List[FavStar]:
+    def favorited_ids(charts: list[Slice]) -> list[FavStar]:
         ids = [chart.id for chart in charts]
         return [
             star.obj_id
@@ -18,7 +18,7 @@ from __future__ import annotations

 import json
 import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING

 import simplejson
 from flask import current_app, g, make_response, request, Response

@@ -315,7 +315,7 @@ class ChartDataRestApi(ChartRestApi):
         return self._get_data_response(command, True)

     def _run_async(
-        self, form_data: Dict[str, Any], command: ChartDataCommand
+        self, form_data: dict[str, Any], command: ChartDataCommand
     ) -> Response:
         """
         Execute command as an async query.

@@ -344,9 +344,9 @@

     def _send_chart_response(
         self,
-        result: Dict[Any, Any],
-        form_data: Optional[Dict[str, Any]] = None,
-        datasource: Optional[Union[BaseDatasource, Query]] = None,
+        result: dict[Any, Any],
+        form_data: dict[str, Any] | None = None,
+        datasource: BaseDatasource | Query | None = None,
     ) -> Response:
         result_type = result["query_context"].result_type
         result_format = result["query_context"].result_format

@@ -408,8 +408,8 @@
         self,
         command: ChartDataCommand,
         force_cached: bool = False,
-        form_data: Optional[Dict[str, Any]] = None,
-        datasource: Optional[Union[BaseDatasource, Query]] = None,
+        form_data: dict[str, Any] | None = None,
+        datasource: BaseDatasource | Query | None = None,
     ) -> Response:
         try:
             result = command.run(force_cached=force_cached)

@@ -421,12 +421,12 @@
         return self._send_chart_response(result, form_data, datasource)

     # pylint: disable=invalid-name, no-self-use
-    def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
+    def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
         return QueryContextCacheLoader.load(cache_key)

     # pylint: disable=no-self-use
     def _create_query_context_from_form(
-        self, form_data: Dict[str, Any]
+        self, form_data: dict[str, Any]
     ) -> QueryContext:
         try:
             return ChartDataQueryContextSchema().load(form_data)
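Unlike most files in this commit, the one above begins with `from __future__ import annotations` (visible in the first hunk header), so pyupgrade can use PEP 604 syntax (`X | None`, `BaseDatasource | Query`) even though the project only targets Python 3.9: with the future import, annotations are never evaluated at runtime. A minimal sketch of the distinction:

    from __future__ import annotations
    from typing import Optional

    # fine on 3.9: the annotation is stored as a string, not evaluated
    def lookup(key: str) -> int | None:
        return None

    # evaluated at runtime, so `int | None` here would need 3.10+;
    # Optional[int] remains the 3.9-safe runtime spelling
    RuntimeAlias = Optional[int]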
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from flask import Request

@@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
         jwt_data = async_query_manager.parse_jwt_from_request(request)
         self._async_channel_id = jwt_data["channel"]

-    def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
+    def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
         job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
         load_chart_data_into_cache.delay(job_metadata, form_data)
         return job_metadata
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict
+from typing import Any

 from flask_babel import lazy_gettext as _

@@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand):
     def __init__(self, query_context: QueryContext):
         self._query_context = query_context

-    def run(self, **kwargs: Any) -> Dict[str, Any]:
+    def run(self, **kwargs: Any) -> dict[str, Any]:
         # caching is handled in query_context.get_df_payload
         # (also evals `force` property)
         cache_query_context = kwargs.get("cache", False)
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict
+from typing import Any

 from superset import cache
 from superset.charts.commands.exceptions import ChartDataCacheLoadError

@@ -22,7 +22,7 @@ from superset.charts.commands.exceptions import ChartDataCacheLoadError

 class QueryContextCacheLoader:  # pylint: disable=too-few-public-methods
     @staticmethod
-    def load(cache_key: str) -> Dict[str, Any]:
+    def load(cache_key: str) -> dict[str, Any]:
         cache_value = cache.get(cache_key)
         if not cache_value:
             raise ChartDataCacheLoadError("Cached data not found")
@@ -27,7 +27,7 @@ for these chart types.
 """

 from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

 import pandas as pd
 from flask_babel import gettext as __

@@ -45,14 +45,14 @@ if TYPE_CHECKING:
     from superset.models.sql_lab import Query


-def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
+def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]:
     """
     Sort columns when combining metrics.

     MultiIndex labels have the metric name as the last element in the
     tuple. We want to sort these according to the list of passed metrics.
     """
-    parts: List[Any] = list(label)
+    parts: list[Any] = list(label)
     metric = parts[-1]
     parts[-1] = metrics.index(metric)
     return tuple(parts)

@@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:

 def pivot_df(  # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
     df: pd.DataFrame,
-    rows: List[str],
-    columns: List[str],
-    metrics: List[str],
+    rows: list[str],
+    columns: list[str],
+    metrics: list[str],
     aggfunc: str = "Sum",
     transpose_pivot: bool = False,
     combine_metrics: bool = False,

@@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str:
     """
     List unique values in a series.
     """
-    return ", ".join(set(str(v) for v in pd.Series.unique(series)))
+    return ", ".join({str(v) for v in pd.Series.unique(series)})


 pivot_v2_aggfunc_map = {

@@ -223,7 +223,7 @@ pivot_v2_aggfunc_map = {

 def pivot_table_v2(
     df: pd.DataFrame,
-    form_data: Dict[str, Any],
+    form_data: dict[str, Any],
     datasource: Optional[Union["BaseDatasource", "Query"]] = None,
 ) -> pd.DataFrame:
     """

@@ -249,7 +249,7 @@

 def pivot_table(
     df: pd.DataFrame,
-    form_data: Dict[str, Any],
+    form_data: dict[str, Any],
     datasource: Optional[Union["BaseDatasource", "Query"]] = None,
 ) -> pd.DataFrame:
     """

@@ -285,7 +285,7 @@

 def table(
     df: pd.DataFrame,
-    form_data: Dict[str, Any],
+    form_data: dict[str, Any],
     datasource: Optional[  # pylint: disable=unused-argument
         Union["BaseDatasource", "Query"]
     ] = None,

@@ -315,10 +315,10 @@ post_processors = {


 def apply_post_process(
-    result: Dict[Any, Any],
-    form_data: Optional[Dict[str, Any]] = None,
+    result: dict[Any, Any],
+    form_data: Optional[dict[str, Any]] = None,
     datasource: Optional[Union["BaseDatasource", "Query"]] = None,
-) -> Dict[Any, Any]:
+) -> dict[Any, Any]:
     form_data = form_data or {}

     viz_type = form_data.get("viz_type")
@@ -18,7 +18,7 @@
 from __future__ import annotations

 import inspect
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

 from flask_babel import gettext as _
 from marshmallow import EXCLUDE, fields, post_load, Schema, validate

@@ -1383,7 +1383,7 @@ class ChartDataQueryObjectSchema(Schema):


 class ChartDataQueryContextSchema(Schema):
-    query_context_factory: Optional[QueryContextFactory] = None
+    query_context_factory: QueryContextFactory | None = None
     datasource = fields.Nested(ChartDataDatasourceSchema)
     queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
     custom_cache_timeout = fields.Integer(

@@ -1407,7 +1407,7 @@

     # pylint: disable=unused-argument
     @post_load
-    def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
+    def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext:
         query_context = self.get_query_context_factory().create(**data)
         return query_context

@@ -18,7 +18,7 @@ import logging
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import List, Optional
+from typing import Optional
 from zipfile import is_zipfile, ZipFile

 import click

@@ -309,7 +309,7 @@ else:
         from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand

         path_object = Path(path)
-        files: List[Path] = []
+        files: list[Path] = []
         if path_object.is_file():
             files.append(path_object)
         elif path_object.exists() and not recursive:

@@ -363,7 +363,7 @@ else:
         sync_metrics = "metrics" in sync_array

         path_object = Path(path)
-        files: List[Path] = []
+        files: list[Path] = []
         if path_object.is_file():
             files.append(path_object)
         elif path_object.exists() and not recursive:
@@ -18,7 +18,7 @@
 import importlib
 import logging
 import pkgutil
-from typing import Any, Dict
+from typing import Any

 import click
 from colorama import Fore, Style

@@ -40,7 +40,7 @@ def superset() -> None:
     """This is a management script for the Superset application."""

     @app.shell_context_processor
-    def make_shell_context() -> Dict[str, Any]:
+    def make_shell_context() -> dict[str, Any]:
         return dict(app=app, db=db)


@@ -79,5 +79,5 @@ def version(verbose: bool) -> None:
     )
     print(Fore.BLUE + "-=" * 15)
     if verbose:
-        print("[DB] : " + "{}".format(db.engine))
+        print("[DB] : " + f"{db.engine}")
     print(Style.RESET_ALL)
@@ -17,7 +17,6 @@
 import json
 from copy import deepcopy
 from textwrap import dedent
-from typing import Set, Tuple

 import click
 from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup

@@ -102,7 +101,7 @@ def native_filters() -> None:
 )
 def upgrade(
     all_: bool,  # pylint: disable=unused-argument
-    dashboard_ids: Tuple[int, ...],
+    dashboard_ids: tuple[int, ...],
 ) -> None:
     """
     Upgrade legacy filter-box charts to native dashboard filters.

@@ -251,7 +250,7 @@ def upgrade(
 )
 def downgrade(
     all_: bool,  # pylint: disable=unused-argument
-    dashboard_ids: Tuple[int, ...],
+    dashboard_ids: tuple[int, ...],
 ) -> None:
     """
     Downgrade native dashboard filters to legacy filter-box charts (where applicable).

@@ -347,7 +346,7 @@ def downgrade(
 )
 def cleanup(
     all_: bool,  # pylint: disable=unused-argument
-    dashboard_ids: Tuple[int, ...],
+    dashboard_ids: tuple[int, ...],
 ) -> None:
     """
     Cleanup obsolete legacy filter-box charts and interim metadata.

@@ -355,7 +354,7 @@ def cleanup(
     Note this operation is irreversible.
     """

-    slice_ids: Set[int] = set()
+    slice_ids: set[int] = set()

     # Cleanup the dashboard which contains legacy fields used for downgrading.
     for dashboard in (
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Type, Union
+from typing import Union

 import click
 from celery.utils.abstract import CallableTask

@@ -75,7 +75,7 @@ def compute_thumbnails(

 def compute_generic_thumbnail(
     friendly_type: str,
-    model_cls: Union[Type[Dashboard], Type[Slice]],
+    model_cls: Union[type[Dashboard], type[Slice]],
     model_id: int,
     compute_func: CallableTask,
 ) -> None:
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.security.sqla.models import User

@@ -45,7 +45,7 @@ class BaseCommand(ABC):

 class CreateMixin:  # pylint: disable=too-few-public-methods
     @staticmethod
-    def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
+    def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
         """
         Populate list of owners, defaulting to the current user if `owner_ids` is
         undefined or empty. If current user is missing in `owner_ids`, current user

@@ -60,7 +60,7 @@ class CreateMixin: # pylint: disable=too-few-public-methods

 class UpdateMixin:  # pylint: disable=too-few-public-methods
     @staticmethod
-    def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
+    def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
         """
         Populate list of owners. If current user is missing in `owner_ids`, current user
         is added unless belonging to the Admin role.
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_babel import lazy_gettext as _
 from marshmallow import ValidationError

@@ -59,7 +59,7 @@ class CommandInvalidError(CommandException):
     def __init__(
         self,
         message: str = "",
-        exceptions: Optional[List[ValidationError]] = None,
+        exceptions: Optional[list[ValidationError]] = None,
     ) -> None:
         self._exceptions = exceptions or []
         super().__init__(message)

@@ -67,14 +67,14 @@ class CommandInvalidError(CommandException):
     def append(self, exception: ValidationError) -> None:
         self._exceptions.append(exception)

-    def extend(self, exceptions: List[ValidationError]) -> None:
+    def extend(self, exceptions: list[ValidationError]) -> None:
         self._exceptions.extend(exceptions)

-    def get_list_classnames(self) -> List[str]:
+    def get_list_classnames(self) -> list[str]:
         return list(sorted({ex.__class__.__name__ for ex in self._exceptions}))

-    def normalized_messages(self) -> Dict[Any, Any]:
-        errors: Dict[Any, Any] = {}
+    def normalized_messages(self) -> dict[Any, Any]:
+        errors: dict[Any, Any] = {}
         for exception in self._exceptions:
             errors.update(exception.normalized_messages())
         return errors
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.

+from collections.abc import Iterator
 from datetime import datetime, timezone
-from typing import Iterator, Tuple

 import yaml

@@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand):
     Command that exports all databases, datasets, charts, dashboards and saved queries.
     """

-    def run(self) -> Iterator[Tuple[str, str]]:
+    def run(self) -> Iterator[tuple[str, str]]:
         metadata = {
             "version": EXPORT_VERSION,
             "type": "assets",
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.

+from collections.abc import Iterator
 from datetime import datetime, timezone
-from typing import Iterator, List, Tuple, Type

 import yaml
 from flask_appbuilder import Model

@@ -30,21 +30,21 @@ METADATA_FILE_NAME = "metadata.yaml"


 class ExportModelsCommand(BaseCommand):
-    dao: Type[BaseDAO] = BaseDAO
-    not_found: Type[CommandException] = CommandException
+    dao: type[BaseDAO] = BaseDAO
+    not_found: type[CommandException] = CommandException

-    def __init__(self, model_ids: List[int], export_related: bool = True):
+    def __init__(self, model_ids: list[int], export_related: bool = True):
         self.model_ids = model_ids
         self.export_related = export_related

         # this will be set when calling validate()
-        self._models: List[Model] = []
+        self._models: list[Model] = []

     @staticmethod
-    def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]:
+    def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]:
         raise NotImplementedError("Subclasses MUST implement _export")

-    def run(self) -> Iterator[Tuple[str, str]]:
+    def run(self) -> Iterator[tuple[str, str]]:
         self.validate()

         metadata = {
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Optional

 from marshmallow import Schema, validate
 from marshmallow.exceptions import ValidationError

@@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand):
     dao = BaseDAO
     model_name = "model"
     prefix = ""
-    schemas: Dict[str, Schema] = {}
+    schemas: dict[str, Schema] = {}
     import_error = CommandException

     # pylint: disable=unused-argument
-    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+    def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
         self.contents = contents
-        self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
-        self.ssh_tunnel_passwords: Dict[str, str] = (
+        self.passwords: dict[str, str] = kwargs.get("passwords") or {}
+        self.ssh_tunnel_passwords: dict[str, str] = (
             kwargs.get("ssh_tunnel_passwords") or {}
         )
-        self.ssh_tunnel_private_keys: Dict[str, str] = (
+        self.ssh_tunnel_private_keys: dict[str, str] = (
             kwargs.get("ssh_tunnel_private_keys") or {}
         )
-        self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+        self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
             kwargs.get("ssh_tunnel_priv_key_passwords") or {}
         )
         self.overwrite: bool = kwargs.get("overwrite", False)
-        self._configs: Dict[str, Any] = {}
+        self._configs: dict[str, Any] = {}

     @staticmethod
     def _import(
-        session: Session, configs: Dict[str, Any], overwrite: bool = False
+        session: Session, configs: dict[str, Any], overwrite: bool = False
     ) -> None:
         raise NotImplementedError("Subclasses MUST implement _import")

     @classmethod
-    def _get_uuids(cls) -> Set[str]:
+    def _get_uuids(cls) -> set[str]:
         return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}

     def run(self) -> None:

@@ -84,11 +84,11 @@ class ImportModelsCommand(BaseCommand):
             raise self.import_error() from ex

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []

         # verify that the metadata file is present and valid
         try:
-            metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+            metadata: Optional[dict[str, str]] = load_metadata(self.contents)
         except ValidationError as exc:
             exceptions.append(exc)
             metadata = None

@@ -114,7 +114,7 @@ class ImportModelsCommand(BaseCommand):
         )

     def _prevent_overwrite_existing_model(  # pylint: disable=invalid-name
-        self, exceptions: List[ValidationError]
+        self, exceptions: list[ValidationError]
     ) -> None:
         """check if the object exists and shouldn't be overwritten"""
         if not self.overwrite:
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from marshmallow import Schema
 from marshmallow.exceptions import ValidationError

@@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand):
     and will overwrite everything.
     """

-    schemas: Dict[str, Schema] = {
+    schemas: dict[str, Schema] = {
         "charts/": ImportV1ChartSchema(),
         "dashboards/": ImportV1DashboardSchema(),
         "datasets/": ImportV1DatasetSchema(),
@@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand):
     }

     # pylint: disable=unused-argument
-    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+    def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
         self.contents = contents
-        self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
-        self.ssh_tunnel_passwords: Dict[str, str] = (
+        self.passwords: dict[str, str] = kwargs.get("passwords") or {}
+        self.ssh_tunnel_passwords: dict[str, str] = (
             kwargs.get("ssh_tunnel_passwords") or {}
         )
-        self.ssh_tunnel_private_keys: Dict[str, str] = (
+        self.ssh_tunnel_private_keys: dict[str, str] = (
             kwargs.get("ssh_tunnel_private_keys") or {}
         )
-        self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+        self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
             kwargs.get("ssh_tunnel_priv_key_passwords") or {}
         )
-        self._configs: Dict[str, Any] = {}
+        self._configs: dict[str, Any] = {}

     @staticmethod
-    def _import(session: Session, configs: Dict[str, Any]) -> None:
+    def _import(session: Session, configs: dict[str, Any]) -> None:
         # import databases first
-        database_ids: Dict[str, int] = {}
+        database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/"):
                 database = import_database(session, config, overwrite=True)
@@ -95,7 +95,7 @@ class ImportAssetsCommand(BaseCommand):
                 import_saved_query(session, config, overwrite=True)

         # import datasets
-        dataset_info: Dict[str, Dict[str, Any]] = {}
+        dataset_info: dict[str, dict[str, Any]] = {}
         for file_name, config in configs.items():
             if file_name.startswith("datasets/"):
                 config["database_id"] = database_ids[config["database_uuid"]]
@@ -107,7 +107,7 @@ class ImportAssetsCommand(BaseCommand):
                 }

         # import charts
-        chart_ids: Dict[str, int] = {}
+        chart_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("charts/"):
                 config.update(dataset_info[config["dataset_uuid"]])
@@ -121,7 +121,7 @@ class ImportAssetsCommand(BaseCommand):
                 dashboard = import_dashboard(session, config, overwrite=True)

                 # set ref in the dashboard_slices table
-                dashboard_chart_ids: List[Dict[str, int]] = []
+                dashboard_chart_ids: list[dict[str, int]] = []
                 for uuid in find_chart_uuids(config["position"]):
                     if uuid not in chart_ids:
                         break
@@ -151,11 +151,11 @@ class ImportAssetsCommand(BaseCommand):
             raise ImportFailedError() from ex

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []

         # verify that the metadata file is present and valid
         try:
-            metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+            metadata: Optional[dict[str, str]] = load_metadata(self.contents)
         except ValidationError as exc:
             exceptions.append(exc)
             metadata = None
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any

 from marshmallow import Schema
 from sqlalchemy.orm import Session

@@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand):

     dao = BaseDAO
     model_name = "model"
-    schemas: Dict[str, Schema] = {
+    schemas: dict[str, Schema] = {
         "charts/": ImportV1ChartSchema(),
         "dashboards/": ImportV1DashboardSchema(),
         "datasets/": ImportV1DatasetSchema(),
@@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand):
     }
     import_error = CommandException

-    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+    def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
         super().__init__(contents, *args, **kwargs)
         self.force_data = kwargs.get("force_data", False)

@@ -81,7 +81,7 @@ class ImportExamplesCommand(ImportModelsCommand):
             raise self.import_error() from ex

     @classmethod
-    def _get_uuids(cls) -> Set[str]:
+    def _get_uuids(cls) -> set[str]:
         # pylint: disable=protected-access
         return (
             ImportDatabasesCommand._get_uuids()
@@ -93,12 +93,12 @@ class ImportExamplesCommand(ImportModelsCommand):
     @staticmethod
     def _import(  # pylint: disable=arguments-differ, too-many-locals, too-many-branches
         session: Session,
-        configs: Dict[str, Any],
+        configs: dict[str, Any],
         overwrite: bool = False,
         force_data: bool = False,
     ) -> None:
         # import databases
-        database_ids: Dict[str, int] = {}
+        database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/"):
                 database = import_database(
@@ -114,7 +114,7 @@ class ImportExamplesCommand(ImportModelsCommand):
         # database was created before its UUID was frozen, so it has a random UUID.
         # We need to determine its ID so we can point the dataset to it.
         examples_db = get_example_database()
-        dataset_info: Dict[str, Dict[str, Any]] = {}
+        dataset_info: dict[str, dict[str, Any]] = {}
         for file_name, config in configs.items():
             if file_name.startswith("datasets/"):
                 # find the ID of the corresponding database
@@ -153,7 +153,7 @@ class ImportExamplesCommand(ImportModelsCommand):
                 }

         # import charts
-        chart_ids: Dict[str, int] = {}
+        chart_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if (
                 file_name.startswith("charts/")
@@ -175,7 +175,7 @@ class ImportExamplesCommand(ImportModelsCommand):
             ).fetchall()

         # import dashboards
-        dashboard_chart_ids: List[Tuple[int, int]] = []
+        dashboard_chart_ids: list[tuple[int, int]] = []
         for file_name, config in configs.items():
             if file_name.startswith("dashboards/"):
                 try:
@@ -15,7 +15,7 @@

 import logging
 from pathlib import Path, PurePosixPath
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 from zipfile import ZipFile

 import yaml

@@ -46,7 +46,7 @@ class MetadataSchema(Schema):
     timestamp = fields.DateTime()


-def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
+def load_yaml(file_name: str, content: str) -> dict[str, Any]:
     """Try to load a YAML file"""
     try:
         return yaml.safe_load(content)
@@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
         raise ValidationError({file_name: "Not a valid YAML file"}) from ex


-def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
+def load_metadata(contents: dict[str, str]) -> dict[str, str]:
     """Apply validation and load a metadata file"""
     if METADATA_FILE_NAME not in contents:
         # if the contents have no METADATA_FILE_NAME this is probably
@@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:


 def validate_metadata_type(
-    metadata: Optional[Dict[str, str]],
+    metadata: Optional[dict[str, str]],
     type_: str,
-    exceptions: List[ValidationError],
+    exceptions: list[ValidationError],
 ) -> None:
     """Validate that the type declared in METADATA_FILE_NAME is correct"""
     if metadata and "type" in metadata:
@@ -96,35 +96,35 @@ def validate_metadata_type(

 # pylint: disable=too-many-locals,too-many-arguments
 def load_configs(
-    contents: Dict[str, str],
-    schemas: Dict[str, Schema],
-    passwords: Dict[str, str],
-    exceptions: List[ValidationError],
-    ssh_tunnel_passwords: Dict[str, str],
-    ssh_tunnel_private_keys: Dict[str, str],
-    ssh_tunnel_priv_key_passwords: Dict[str, str],
-) -> Dict[str, Any]:
-    configs: Dict[str, Any] = {}
+    contents: dict[str, str],
+    schemas: dict[str, Schema],
+    passwords: dict[str, str],
+    exceptions: list[ValidationError],
+    ssh_tunnel_passwords: dict[str, str],
+    ssh_tunnel_private_keys: dict[str, str],
+    ssh_tunnel_priv_key_passwords: dict[str, str],
+) -> dict[str, Any]:
+    configs: dict[str, Any] = {}

     # load existing databases so we can apply the password validation
-    db_passwords: Dict[str, str] = {
+    db_passwords: dict[str, str] = {
         str(uuid): password
         for uuid, password in db.session.query(Database.uuid, Database.password).all()
     }
     # load existing ssh_tunnels so we can apply the password validation
-    db_ssh_tunnel_passwords: Dict[str, str] = {
+    db_ssh_tunnel_passwords: dict[str, str] = {
         str(uuid): password
         for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
     }
     # load existing ssh_tunnels so we can apply the private_key validation
-    db_ssh_tunnel_private_keys: Dict[str, str] = {
+    db_ssh_tunnel_private_keys: dict[str, str] = {
         str(uuid): private_key
         for uuid, private_key in db.session.query(
             SSHTunnel.uuid, SSHTunnel.private_key
         ).all()
     }
     # load existing ssh_tunnels so we can apply the private_key_password validation
-    db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
+    db_ssh_tunnel_priv_key_passws: dict[str, str] = {
         str(uuid): private_key_password
         for uuid, private_key_password in db.session.query(
             SSHTunnel.uuid, SSHTunnel.private_key_password
@@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool:
     return True


-def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]:
+def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
     return {
         remove_root(file_name): bundle.read(file_name).decode()
         for file_name in bundle.namelist()
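The shrinking import lines in these hunks look like the combined effect of the two new hooks: pyupgrade rewrites the annotation usages, which leaves names such as Dict and List unused, and pycln then strips the dead names from the import statements. Roughly, on an assumed toy module:

    from typing import Any, Dict  # pyupgrade rewrites the usage below...

    def echo(payload: Dict[str, Any]) -> Dict[str, Any]:
        return payload

    # ...and after pyupgrade + pycln the module becomes:
    from typing import Any

    def echo(payload: dict[str, Any]) -> dict[str, Any]:
        return payload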
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING

 from flask import g
 from flask_appbuilder.security.sqla.models import Role, User

@@ -37,9 +37,9 @@ if TYPE_CHECKING:


 def populate_owners(
-    owner_ids: Optional[List[int]],
+    owner_ids: list[int] | None,
     default_to_user: bool,
-) -> List[User]:
+) -> list[User]:
     """
     Helper function for commands, will fetch all users from owners id's

@@ -63,13 +63,13 @@ def populate_owners(
     return owners


-def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
+def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
     """
     Helper function for commands, will fetch all roles from roles id's
     :raises RolesNotFoundValidationError: If a role in the input list is not found
     :param role_ids: A List of roles by id's
     """
-    roles: List[Role] = []
+    roles: list[Role] = []
     if role_ids:
         roles = security_manager.find_roles_by_id(role_ids)
         if len(roles) != len(role_ids):
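This file also shows the second rewrite pyupgrade performs: Optional[X] becomes the PEP 604 spelling X | None, but only in modules that begin with from __future__ import annotations. Under that future import, annotations are never evaluated at runtime, so the | union syntax (a Python 3.10 runtime feature) is safe even on the 3.9 floor. A rough illustration of that assumption (toy code, not from this commit):

    from __future__ import annotations  # annotations stay unevaluated strings

    def greet(name: str | None = None) -> str:
        # The `str | None` above is never evaluated at runtime,
        # so this module still imports cleanly on Python 3.9.
        return f"Hello, {name or 'world'}!"

Files without the future import, such as the importer commands above, keep Optional for exactly this reason.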
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 from enum import Enum
-from typing import Set


 class ChartDataResultFormat(str, Enum):
@@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum):
     XLSX = "xlsx"

     @classmethod
-    def table_like(cls) -> Set["ChartDataResultFormat"]:
+    def table_like(cls) -> set["ChartDataResultFormat"]:
         return {cls.CSV} | {cls.XLSX}

@@ -17,7 +17,7 @@
 from __future__ import annotations

 import copy
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING

 from flask_babel import _

@@ -49,7 +49,7 @@ def _get_datasource(

 def _get_columns(
     query_context: QueryContext, query_obj: QueryObject, _: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
     return {
         "data": [
@@ -65,7 +65,7 @@ def _get_columns(

 def _get_timegrains(
     query_context: QueryContext, query_obj: QueryObject, _: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
     return {
         "data": [
@@ -83,7 +83,7 @@ def _get_query(
     query_context: QueryContext,
     query_obj: QueryObject,
     _: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
     result = {"language": datasource.query_language}
     try:
@@ -96,8 +96,8 @@ def _get_query(
 def _get_full(
     query_context: QueryContext,
     query_obj: QueryObject,
-    force_cached: Optional[bool] = False,
-) -> Dict[str, Any]:
+    force_cached: bool | None = False,
+) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
     result_type = query_obj.result_type or query_context.result_type
     payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
@@ -141,7 +141,7 @@ def _get_full(

 def _get_samples(
     query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
     query_obj = copy.copy(query_obj)
     query_obj.is_timeseries = False
@@ -162,7 +162,7 @@ def _get_samples(

 def _get_drill_detail(
     query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     # todo(yongjie): Remove this function,
     # when determining whether samples should be applied to the time filter.
     datasource = _get_datasource(query_context, query_obj)
@@ -183,13 +183,13 @@ def _get_drill_detail(

 def _get_results(
     query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     payload = _get_full(query_context, query_obj, force_cached)
     return payload


-_result_type_functions: Dict[
-    ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]]
+_result_type_functions: dict[
+    ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]]
 ] = {
     ChartDataResultType.COLUMNS: _get_columns,
     ChartDataResultType.TIMEGRAINS: _get_timegrains,
@@ -210,7 +210,7 @@ def get_query_results(
     query_context: QueryContext,
     query_obj: QueryObject,
     force_cached: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Return result payload for a chart data request.

@@ -17,7 +17,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, ClassVar, TYPE_CHECKING

 import pandas as pd

@@ -47,15 +47,15 @@ class QueryContext:
     enforce_numerical_metrics: ClassVar[bool] = True

     datasource: BaseDatasource
-    slice_: Optional[Slice] = None
-    queries: List[QueryObject]
-    form_data: Optional[Dict[str, Any]]
+    slice_: Slice | None = None
+    queries: list[QueryObject]
+    form_data: dict[str, Any] | None
     result_type: ChartDataResultType
     result_format: ChartDataResultFormat
     force: bool
-    custom_cache_timeout: Optional[int]
+    custom_cache_timeout: int | None

-    cache_values: Dict[str, Any]
+    cache_values: dict[str, Any]

     _processor: QueryContextProcessor

@@ -65,14 +65,14 @@ class QueryContext:
         self,
         *,
         datasource: BaseDatasource,
-        queries: List[QueryObject],
-        slice_: Optional[Slice],
-        form_data: Optional[Dict[str, Any]],
+        queries: list[QueryObject],
+        slice_: Slice | None,
+        form_data: dict[str, Any] | None,
         result_type: ChartDataResultType,
         result_format: ChartDataResultFormat,
         force: bool = False,
-        custom_cache_timeout: Optional[int] = None,
-        cache_values: Dict[str, Any],
+        custom_cache_timeout: int | None = None,
+        cache_values: dict[str, Any],
     ) -> None:
         self.datasource = datasource
         self.slice_ = slice_
@@ -88,18 +88,18 @@ class QueryContext:
     def get_data(
         self,
         df: pd.DataFrame,
-    ) -> Union[str, List[Dict[str, Any]]]:
+    ) -> str | list[dict[str, Any]]:
         return self._processor.get_data(df)

     def get_payload(
         self,
-        cache_query_context: Optional[bool] = False,
+        cache_query_context: bool | None = False,
         force_cached: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Returns the query results with both metadata and data"""
         return self._processor.get_payload(cache_query_context, force_cached)

-    def get_cache_timeout(self) -> Optional[int]:
+    def get_cache_timeout(self) -> int | None:
         if self.custom_cache_timeout is not None:
             return self.custom_cache_timeout
         if self.slice_ and self.slice_.cache_timeout is not None:
@@ -110,14 +110,14 @@ class QueryContext:
             return self.datasource.database.cache_timeout
         return None

-    def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+    def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
         return self._processor.query_cache_key(query_obj, **kwargs)

     def get_df_payload(
         self,
         query_obj: QueryObject,
-        force_cached: Optional[bool] = False,
-    ) -> Dict[str, Any]:
+        force_cached: bool | None = False,
+    ) -> dict[str, Any]:
         return self._processor.get_df_payload(
             query_obj=query_obj,
             force_cached=force_cached,
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import Any, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

 from superset import app, db
 from superset.charts.dao import ChartDAO

@@ -48,12 +48,12 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
         self,
         *,
         datasource: DatasourceDict,
-        queries: List[Dict[str, Any]],
-        form_data: Optional[Dict[str, Any]] = None,
-        result_type: Optional[ChartDataResultType] = None,
-        result_format: Optional[ChartDataResultFormat] = None,
+        queries: list[dict[str, Any]],
+        form_data: dict[str, Any] | None = None,
+        result_type: ChartDataResultType | None = None,
+        result_format: ChartDataResultFormat | None = None,
         force: bool = False,
-        custom_cache_timeout: Optional[int] = None,
+        custom_cache_timeout: int | None = None,
     ) -> QueryContext:
         datasource_model_instance = None
         if datasource:
@@ -101,13 +101,13 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
             datasource_id=int(datasource["id"]),
         )

-    def _get_slice(self, slice_id: Any) -> Optional[Slice]:
+    def _get_slice(self, slice_id: Any) -> Slice | None:
         return ChartDAO.find_by_id(slice_id)

     def _process_query_object(
         self,
         datasource: BaseDatasource,
-        form_data: Optional[Dict[str, Any]],
+        form_data: dict[str, Any] | None,
         query_object: QueryObject,
     ) -> QueryObject:
         self._apply_granularity(query_object, form_data, datasource)
@@ -117,7 +117,7 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
     def _apply_granularity(
         self,
         query_object: QueryObject,
-        form_data: Optional[Dict[str, Any]],
+        form_data: dict[str, Any] | None,
         datasource: BaseDatasource,
     ) -> None:
         temporal_columns = {
@@ -19,7 +19,7 @@ from __future__ import annotations
 import copy
 import logging
 import re
-from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, ClassVar, TYPE_CHECKING

 import numpy as np
 import pandas as pd

@@ -77,8 +77,8 @@ logger = logging.getLogger(__name__)

 class CachedTimeOffset(TypedDict):
     df: pd.DataFrame
-    queries: List[str]
-    cache_keys: List[Optional[str]]
+    queries: list[str]
+    cache_keys: list[str | None]


 class QueryContextProcessor:
@@ -102,8 +102,8 @@ class QueryContextProcessor:
     enforce_numerical_metrics: ClassVar[bool] = True

     def get_df_payload(
-        self, query_obj: QueryObject, force_cached: Optional[bool] = False
-    ) -> Dict[str, Any]:
+        self, query_obj: QueryObject, force_cached: bool | None = False
+    ) -> dict[str, Any]:
         """Handles caching around the df payload retrieval"""
         cache_key = self.query_cache_key(query_obj)
         timeout = self.get_cache_timeout()
@@ -181,7 +181,7 @@ class QueryContextProcessor:
             "label_map": label_map,
         }

-    def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+    def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
         """
         Returns a QueryObject cache key for objects in self.queries
         """
@@ -248,8 +248,8 @@ class QueryContextProcessor:
     def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
         # todo: should support "python_date_format" and "get_column" in each datasource
         def _get_timestamp_format(
-            source: BaseDatasource, column: Optional[str]
-        ) -> Optional[str]:
+            source: BaseDatasource, column: str | None
+        ) -> str | None:
             column_obj = source.get_column(column)
             if (
                 column_obj
@@ -315,9 +315,9 @@ class QueryContextProcessor:
         query_context = self._query_context
         # ensure query_object is immutable
         query_object_clone = copy.copy(query_object)
-        queries: List[str] = []
-        cache_keys: List[Optional[str]] = []
-        rv_dfs: List[pd.DataFrame] = [df]
+        queries: list[str] = []
+        cache_keys: list[str | None] = []
+        rv_dfs: list[pd.DataFrame] = [df]

         time_offsets = query_object.time_offsets
         outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
@@ -449,7 +449,7 @@ class QueryContextProcessor:
         rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
         return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)

-    def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
+    def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
         if self._query_context.result_format in ChartDataResultFormat.table_like():
             include_index = not isinstance(df.index, pd.RangeIndex)
             columns = list(df.columns)
@@ -470,9 +470,9 @@ class QueryContextProcessor:

     def get_payload(
         self,
-        cache_query_context: Optional[bool] = False,
+        cache_query_context: bool | None = False,
         force_cached: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Returns the query results with both metadata and data"""

         # Get all the payloads from the QueryObjects
@@ -522,13 +522,13 @@ class QueryContextProcessor:

         return generate_cache_key(cache_dict, key_prefix)

-    def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
+    def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]:
         """
         :param query_context:
         :param query_obj:
         :return:
         """
-        annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
+        annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj)
         for annotation_layer in [
             layer
             for layer in query_obj.annotation_layers
@@ -541,7 +541,7 @@ class QueryContextProcessor:
         return annotation_data

     @staticmethod
-    def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
+    def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]:
         annotation_data = {}
         annotation_layers = [
             layer
@@ -576,8 +576,8 @@ class QueryContextProcessor:

     @staticmethod
     def get_viz_annotation_data(
-        annotation_layer: Dict[str, Any], force: bool
-    ) -> Dict[str, Any]:
+        annotation_layer: dict[str, Any], force: bool
+    ) -> dict[str, Any]:
         chart = ChartDAO.find_by_id(annotation_layer["value"])
         if not chart:
             raise QueryObjectValidationError(_("The chart does not exist"))
@@ -21,7 +21,7 @@ import json
 import logging
 from datetime import datetime
 from pprint import pformat
-from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
+from typing import Any, NamedTuple, TYPE_CHECKING

 from flask import g
 from flask_babel import gettext as _

@@ -81,58 +81,58 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     and druid. The query objects are constructed on the client.
     """

-    annotation_layers: List[Dict[str, Any]]
-    applied_time_extras: Dict[str, str]
+    annotation_layers: list[dict[str, Any]]
+    applied_time_extras: dict[str, str]
     apply_fetch_values_predicate: bool
-    columns: List[Column]
-    datasource: Optional[BaseDatasource]
-    extras: Dict[str, Any]
-    filter: List[QueryObjectFilterClause]
-    from_dttm: Optional[datetime]
-    granularity: Optional[str]
-    inner_from_dttm: Optional[datetime]
-    inner_to_dttm: Optional[datetime]
+    columns: list[Column]
+    datasource: BaseDatasource | None
+    extras: dict[str, Any]
+    filter: list[QueryObjectFilterClause]
+    from_dttm: datetime | None
+    granularity: str | None
+    inner_from_dttm: datetime | None
+    inner_to_dttm: datetime | None
     is_rowcount: bool
     is_timeseries: bool
-    metrics: Optional[List[Metric]]
+    metrics: list[Metric] | None
     order_desc: bool
-    orderby: List[OrderBy]
-    post_processing: List[Dict[str, Any]]
-    result_type: Optional[ChartDataResultType]
-    row_limit: Optional[int]
+    orderby: list[OrderBy]
+    post_processing: list[dict[str, Any]]
+    result_type: ChartDataResultType | None
+    row_limit: int | None
     row_offset: int
-    series_columns: List[Column]
+    series_columns: list[Column]
     series_limit: int
-    series_limit_metric: Optional[Metric]
-    time_offsets: List[str]
-    time_shift: Optional[str]
-    time_range: Optional[str]
-    to_dttm: Optional[datetime]
+    series_limit_metric: Metric | None
+    time_offsets: list[str]
+    time_shift: str | None
+    time_range: str | None
+    to_dttm: datetime | None

     def __init__(  # pylint: disable=too-many-locals
         self,
         *,
-        annotation_layers: Optional[List[Dict[str, Any]]] = None,
-        applied_time_extras: Optional[Dict[str, str]] = None,
+        annotation_layers: list[dict[str, Any]] | None = None,
+        applied_time_extras: dict[str, str] | None = None,
         apply_fetch_values_predicate: bool = False,
-        columns: Optional[List[Column]] = None,
-        datasource: Optional[BaseDatasource] = None,
-        extras: Optional[Dict[str, Any]] = None,
-        filters: Optional[List[QueryObjectFilterClause]] = None,
-        granularity: Optional[str] = None,
+        columns: list[Column] | None = None,
+        datasource: BaseDatasource | None = None,
+        extras: dict[str, Any] | None = None,
+        filters: list[QueryObjectFilterClause] | None = None,
+        granularity: str | None = None,
         is_rowcount: bool = False,
-        is_timeseries: Optional[bool] = None,
-        metrics: Optional[List[Metric]] = None,
+        is_timeseries: bool | None = None,
+        metrics: list[Metric] | None = None,
         order_desc: bool = True,
-        orderby: Optional[List[OrderBy]] = None,
-        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
-        row_limit: Optional[int],
-        row_offset: Optional[int] = None,
-        series_columns: Optional[List[Column]] = None,
+        orderby: list[OrderBy] | None = None,
+        post_processing: list[dict[str, Any] | None] | None = None,
+        row_limit: int | None,
+        row_offset: int | None = None,
+        series_columns: list[Column] | None = None,
         series_limit: int = 0,
-        series_limit_metric: Optional[Metric] = None,
-        time_range: Optional[str] = None,
-        time_shift: Optional[str] = None,
+        series_limit_metric: Metric | None = None,
+        time_range: str | None = None,
+        time_shift: str | None = None,
         **kwargs: Any,
     ):
         self._set_annotation_layers(annotation_layers)
@@ -166,7 +166,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         self._move_deprecated_extra_fields(kwargs)

     def _set_annotation_layers(
-        self, annotation_layers: Optional[List[Dict[str, Any]]]
+        self, annotation_layers: list[dict[str, Any]] | None
     ) -> None:
         self.annotation_layers = [
             layer
@@ -175,14 +175,14 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             if layer["annotationType"] != "FORMULA"
         ]

-    def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
+    def _set_is_timeseries(self, is_timeseries: bool | None) -> None:
         # is_timeseries is True if time column is in either columns or groupby
         # (both are dimensions)
         self.is_timeseries = (
             is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
         )

-    def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
+    def _set_metrics(self, metrics: list[Metric] | None = None) -> None:
         # Support metric reference/definition in the format of
         # 1. 'metric_name' - name of predefined metric
         # 2. { label: 'label_name' } - legacy format for a predefined metric
@@ -195,16 +195,16 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         ]

     def _set_post_processing(
-        self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
+        self, post_processing: list[dict[str, Any] | None] | None
     ) -> None:
         post_processing = post_processing or []
         self.post_processing = [post_proc for post_proc in post_processing if post_proc]

     def _init_series_columns(
         self,
-        series_columns: Optional[List[Column]],
-        metrics: Optional[List[Metric]],
-        is_timeseries: Optional[bool],
+        series_columns: list[Column] | None,
+        metrics: list[Metric] | None,
+        is_timeseries: bool | None,
     ) -> None:
         if series_columns:
             self.series_columns = series_columns
@@ -213,7 +213,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         else:
             self.series_columns = []

-    def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
+    def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None:
         # rename deprecated fields
         for field in DEPRECATED_FIELDS:
             if field.old_name in kwargs:
@@ -233,7 +233,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             )
             setattr(self, field.new_name, value)

-    def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None:
+    def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None:
         # move deprecated extras fields to extras
         for field in DEPRECATED_EXTRAS_FIELDS:
             if field.old_name in kwargs:
@@ -256,19 +256,19 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             self.extras[field.new_name] = value

     @property
-    def metric_names(self) -> List[str]:
+    def metric_names(self) -> list[str]:
         """Return metrics names (labels), coerce adhoc metrics to strings."""
         return get_metric_names(self.metrics or [])

     @property
-    def column_names(self) -> List[str]:
+    def column_names(self) -> list[str]:
         """Return column names (labels). Gives priority to groupbys if both groupbys
         and metrics are non-empty, otherwise returns column labels."""
         return get_column_names(self.columns)

     def validate(
-        self, raise_exceptions: Optional[bool] = True
-    ) -> Optional[QueryObjectValidationError]:
+        self, raise_exceptions: bool | None = True
+    ) -> QueryObjectValidationError | None:
         """Validate query object"""
         try:
             self._validate_there_are_no_missing_series()
@@ -314,7 +314,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             )
         )

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         query_object_dict = {
             "apply_fetch_values_predicate": self.apply_fetch_values_predicate,
             "columns": self.columns,
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

 from superset.common.chart_data import ChartDataResultType
 from superset.common.query_object import QueryObject

@@ -31,13 +31,13 @@ if TYPE_CHECKING:


 class QueryObjectFactory:  # pylint: disable=too-few-public-methods
-    _config: Dict[str, Any]
+    _config: dict[str, Any]
     _datasource_dao: DatasourceDAO
     _session_maker: sessionmaker

     def __init__(
         self,
-        app_configurations: Dict[str, Any],
+        app_configurations: dict[str, Any],
         _datasource_dao: DatasourceDAO,
         session_maker: sessionmaker,
     ):
@@ -48,11 +48,11 @@ class QueryObjectFactory:  # pylint: disable=too-few-public-methods
     def create(  # pylint: disable=too-many-arguments
         self,
         parent_result_type: ChartDataResultType,
-        datasource: Optional[DatasourceDict] = None,
-        extras: Optional[Dict[str, Any]] = None,
-        row_limit: Optional[int] = None,
-        time_range: Optional[str] = None,
-        time_shift: Optional[str] = None,
+        datasource: DatasourceDict | None = None,
+        extras: dict[str, Any] | None = None,
+        row_limit: int | None = None,
+        time_range: str | None = None,
+        time_shift: str | None = None,
         **kwargs: Any,
     ) -> QueryObject:
         datasource_model_instance = None
@@ -84,13 +84,13 @@ class QueryObjectFactory:  # pylint: disable=too-few-public-methods

     def _process_extras(  # pylint: disable=no-self-use
         self,
-        extras: Optional[Dict[str, Any]],
-    ) -> Dict[str, Any]:
+        extras: dict[str, Any] | None,
+    ) -> dict[str, Any]:
         extras = extras or {}
         return extras

     def _process_row_limit(
-        self, row_limit: Optional[int], result_type: ChartDataResultType
+        self, row_limit: int | None, result_type: ChartDataResultType
     ) -> int:
         default_row_limit = (
             self._config["SAMPLES_ROW_LIMIT"]
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, List
+from typing import Any

 from sqlalchemy import MetaData
 from sqlalchemy.exc import IntegrityError

@@ -25,7 +25,7 @@ from superset.tags.models import ObjectTypes, TagTypes


 def add_types_to_charts(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     slices = metadata.tables["slices"]

@@ -57,7 +57,7 @@ def add_types_to_charts(


 def add_types_to_dashboards(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     dashboard_table = metadata.tables["dashboards"]

@@ -89,7 +89,7 @@ def add_types_to_dashboards(


 def add_types_to_saved_queries(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     saved_query = metadata.tables["saved_query"]

@@ -121,7 +121,7 @@ def add_types_to_saved_queries(


 def add_types_to_datasets(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     tables = metadata.tables["tables"]

@@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None:


 def add_owners_to_charts(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     slices = metadata.tables["slices"]

@@ -273,7 +273,7 @@ def add_owners_to_charts(


 def add_owners_to_dashboards(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     dashboard_table = metadata.tables["dashboards"]

@@ -309,7 +309,7 @@ def add_owners_to_dashboards(


 def add_owners_to_saved_queries(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     saved_query = metadata.tables["saved_query"]

@@ -345,7 +345,7 @@ def add_owners_to_saved_queries(


 def add_owners_to_datasets(
-    metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+    metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
 ) -> None:
     tables = metadata.tables["tables"]

|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, List, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
|||
def left_join_df(
|
||||
left_df: pd.DataFrame,
|
||||
right_df: pd.DataFrame,
|
||||
join_keys: List[str],
|
||||
join_keys: list[str],
|
||||
) -> pd.DataFrame:
|
||||
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
|
||||
df.reset_index(inplace=True)
|
||||
|
|
|
@@ -17,7 +17,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any

 from flask_caching import Cache
 from pandas import DataFrame

@@ -37,7 +37,7 @@ config = app.config
 stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 logger = logging.getLogger(__name__)

-_cache: Dict[CacheRegion, Cache] = {
+_cache: dict[CacheRegion, Cache] = {
     CacheRegion.DEFAULT: cache_manager.cache,
     CacheRegion.DATA: cache_manager.data_cache,
 }
@@ -53,17 +53,17 @@ class QueryCacheManager:
         self,
         df: DataFrame = DataFrame(),
         query: str = "",
-        annotation_data: Optional[Dict[str, Any]] = None,
-        applied_template_filters: Optional[List[str]] = None,
-        applied_filter_columns: Optional[List[Column]] = None,
-        rejected_filter_columns: Optional[List[Column]] = None,
-        status: Optional[str] = None,
-        error_message: Optional[str] = None,
+        annotation_data: dict[str, Any] | None = None,
+        applied_template_filters: list[str] | None = None,
+        applied_filter_columns: list[Column] | None = None,
+        rejected_filter_columns: list[Column] | None = None,
+        status: str | None = None,
+        error_message: str | None = None,
         is_loaded: bool = False,
-        stacktrace: Optional[str] = None,
-        is_cached: Optional[bool] = None,
-        cache_dttm: Optional[str] = None,
-        cache_value: Optional[Dict[str, Any]] = None,
+        stacktrace: str | None = None,
+        is_cached: bool | None = None,
+        cache_dttm: str | None = None,
+        cache_value: dict[str, Any] | None = None,
     ) -> None:
         self.df = df
         self.query = query
@@ -85,10 +85,10 @@ class QueryCacheManager:
         self,
         key: str,
         query_result: QueryResult,
-        annotation_data: Optional[Dict[str, Any]] = None,
-        force_query: Optional[bool] = False,
-        timeout: Optional[int] = None,
-        datasource_uid: Optional[str] = None,
+        annotation_data: dict[str, Any] | None = None,
+        force_query: bool | None = False,
+        timeout: int | None = None,
+        datasource_uid: str | None = None,
         region: CacheRegion = CacheRegion.DEFAULT,
     ) -> None:
         """
@@ -136,11 +136,11 @@ class QueryCacheManager:
     @classmethod
     def get(
         cls,
-        key: Optional[str],
+        key: str | None,
         region: CacheRegion = CacheRegion.DEFAULT,
-        force_query: Optional[bool] = False,
-        force_cached: Optional[bool] = False,
-    ) -> "QueryCacheManager":
+        force_query: bool | None = False,
+        force_cached: bool | None = False,
+    ) -> QueryCacheManager:
         """
         Initialize QueryCacheManager by query-cache key
         """
@@ -190,10 +190,10 @@ class QueryCacheManager:

     @staticmethod
     def set(
-        key: Optional[str],
-        value: Dict[str, Any],
-        timeout: Optional[int] = None,
-        datasource_uid: Optional[str] = None,
+        key: str | None,
+        value: dict[str, Any],
+        timeout: int | None = None,
+        datasource_uid: str | None = None,
         region: CacheRegion = CacheRegion.DEFAULT,
     ) -> None:
         """
@@ -204,7 +204,7 @@ class QueryCacheManager:

     @staticmethod
     def delete(
-        key: Optional[str],
+        key: str | None,
         region: CacheRegion = CacheRegion.DEFAULT,
     ) -> None:
         if key:
@@ -212,7 +212,7 @@ class QueryCacheManager:

     @staticmethod
     def has(
-        key: Optional[str],
+        key: str | None,
         region: CacheRegion = CacheRegion.DEFAULT,
     ) -> bool:
         return bool(_cache[region].get(key)) if key else False
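The same future import explains the change of the get() return annotation from the quoted "QueryCacheManager" to a bare QueryCacheManager: with postponed evaluation a class may reference itself in its own signatures without the string workaround. A small sketch of the pattern (hypothetical class, not from this commit):

    from __future__ import annotations

    class Node:
        def clone(self) -> Node:  # no quotes needed; resolved lazily
            return Node()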
@@ -17,7 +17,7 @@
 from __future__ import annotations

 from datetime import datetime
-from typing import Any, cast, Dict, Optional, Tuple
+from typing import Any, cast

 from superset import app
 from superset.common.query_object import QueryObject

@@ -26,10 +26,10 @@ from superset.utils.date_parser import get_since_until


 def get_since_until_from_time_range(
-    time_range: Optional[str] = None,
-    time_shift: Optional[str] = None,
-    extras: Optional[Dict[str, Any]] = None,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+    time_range: str | None = None,
+    time_shift: str | None = None,
+    extras: dict[str, Any] | None = None,
+) -> tuple[datetime | None, datetime | None]:
     return get_since_until(
         relative_start=(extras or {}).get(
             "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
@@ -45,7 +45,7 @@ def get_since_until_from_time_range(
 # pylint: disable=invalid-name
 def get_since_until_from_query_object(
     query_object: QueryObject,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+) -> tuple[datetime | None, datetime | None]:
     """
     this function will return since and until by tuple if
     1) the time_range is in the query object.
@@ -33,20 +33,7 @@ import sys
 from collections import OrderedDict
 from datetime import timedelta
 from email.mime.multipart import MIMEMultipart
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    TypedDict,
-    Union,
-)
+from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict

 import pkg_resources
 from cachelib.base import BaseCache

@@ -114,17 +101,17 @@ PACKAGE_JSON_FILE = pkg_resources.resource_filename(
 FAVICONS = [{"href": "/static/assets/images/favicon.png"}]


-def _try_json_readversion(filepath: str) -> Optional[str]:
+def _try_json_readversion(filepath: str) -> str | None:
     try:
-        with open(filepath, "r") as f:
+        with open(filepath) as f:
             return json.load(f).get("version")
     except Exception:  # pylint: disable=broad-except
         return None


-def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
+def _try_json_readsha(filepath: str, length: int) -> str | None:
     try:
-        with open(filepath, "r") as f:
+        with open(filepath) as f:
             return json.load(f).get("GIT_SHA")[:length]
     except Exception:  # pylint: disable=broad-except
         return None
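Not every rewrite in this file concerns typing: pyupgrade also drops redundant open() mode arguments, as in the two hunks above, because "r" (read, text) is already the default. The two forms below are equivalent (illustrative path, assumed):

    with open("version.json", "r") as f:  # explicit default mode
        data = f.read()

    with open("version.json") as f:  # what pyupgrade rewrites it to
        data = f.read()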
@@ -275,7 +262,7 @@ ENABLE_PROXY_FIX = False
 PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1}

 # Configuration for scheduling queries from SQL Lab.
-SCHEDULED_QUERIES: Dict[str, Any] = {}
+SCHEDULED_QUERIES: dict[str, Any] = {}

 # ------------------------------
 # GLOBALS FOR APP Builder
@@ -294,7 +281,7 @@ LOGO_TARGET_PATH = None
 LOGO_TOOLTIP = ""

 # Specify any text that should appear to the right of the logo
-LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = ""
+LOGO_RIGHT_TEXT: Callable[[], str] | str = ""

 # Enables SWAGGER UI for superset openapi spec
 # ex: http://localhost:8080/swagger/v1
@@ -347,7 +334,7 @@ AUTH_TYPE = AUTH_DB
 # Grant public role the same set of permissions as for a selected builtin role.
 # This is useful if one wants to enable anonymous users to view
 # dashboards. Explicit grant on specific datasets is still required.
-PUBLIC_ROLE_LIKE: Optional[str] = None
+PUBLIC_ROLE_LIKE: str | None = None

 # ---------------------------------------------------
 # Babel config for translations
@@ -390,8 +377,8 @@ LANGUAGES = {}
 class D3Format(TypedDict, total=False):
     decimal: str
     thousands: str
-    grouping: List[int]
-    currency: List[str]
+    grouping: list[int]
+    currency: list[str]


 D3_FORMAT: D3Format = {}
@@ -404,7 +391,7 @@ D3_FORMAT: D3Format = {}
 # For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
 # and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
 # will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
-DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
+DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
     # Experimental feature introducing a client (browser) cache
     "CLIENT_CACHE": False,  # deprecated
     "DISABLE_DATASET_SOURCE_EDIT": False,  # deprecated
@@ -527,7 +514,7 @@ DEFAULT_FEATURE_FLAGS.update(
 )

 # This is merely a default.
-FEATURE_FLAGS: Dict[str, bool] = {}
+FEATURE_FLAGS: dict[str, bool] = {}

 # A function that receives a dict of all feature flags
 # (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
@@ -543,7 +530,7 @@ FEATURE_FLAGS: Dict[str, bool] = {}
 #     if hasattr(g, "user") and g.user.is_active:
 #         feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
 #     return feature_flags_dict
-GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
+GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None
 # A function that receives a feature flag name and an optional default value.
 # Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
 # evaluation of all feature flags when just evaluating a single one.
@@ -551,7 +538,7 @@ GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] =
 # Note that the default `get_feature_flags` will evaluate each feature with this
 # callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
 # and IS_FEATURE_ENABLED_FUNC in conjunction.
-IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
+IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
 # A function that expands/overrides the frontend `bootstrap_data.common` object.
 # Can be used to implement custom frontend functionality,
 # or dynamically change certain configs.
@@ -563,7 +550,7 @@ IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
 # Takes as a parameter the common bootstrap payload before transformations.
 # Returns a dict containing data that should be added or overridden to the payload.
 COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
-    [Dict[str, Any]], Dict[str, Any]
+    [dict[str, Any]], dict[str, Any]
 ] = lambda data: {}  # default: empty dict

 # EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
@@ -580,7 +567,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
 # }]

 # This is merely a default
-EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
+EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = []

 # THEME_OVERRIDES is used for adding custom theme to superset
 # example code for "My theme" custom scheme
@@ -599,7 +586,7 @@ EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
 #     }
 # }

-THEME_OVERRIDES: Dict[str, Any] = {}
+THEME_OVERRIDES: dict[str, Any] = {}

 # EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes
 # EXTRA_SEQUENTIAL_COLOR_SCHEMES = [
@@ -615,7 +602,7 @@ THEME_OVERRIDES: Dict[str, Any] = {}
 # }]

 # This is merely a default
-EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
+EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = []

 # ---------------------------------------------------
 # Thumbnail config (behind feature flag)
@@ -626,7 +613,7 @@ EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
 # `superset.tasks.types.ExecutorType` for a full list of executor options.
 # To always use a fixed user account, use the following configuration:
 # THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM]
-THUMBNAIL_SELENIUM_USER: Optional[str] = "admin"
+THUMBNAIL_SELENIUM_USER: str | None = "admin"
 THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]

 # By default, thumbnail digests are calculated based on various parameters in the
@@ -639,10 +626,10 @@ THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
 # `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in
 # user if the executor type is equal to `ExecutorType.CURRENT_USER`)
 # and return the final digest string:
-THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[
+THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | (
     Callable[[Dashboard, ExecutorType, str], str]
-] = None
-THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None
+) = None
+THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None

 THUMBNAIL_CACHE_CONFIG: CacheConfig = {
     "CACHE_TYPE": "NullCache",
@@ -714,7 +701,7 @@ STORE_CACHE_KEYS_IN_METADATA_DB = False

 # CORS Options
 ENABLE_CORS = False
-CORS_OPTIONS: Dict[Any, Any] = {}
+CORS_OPTIONS: dict[Any, Any] = {}

 # Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner.
 # Disabling this option is not recommended for security reasons. If you wish to allow
@@ -736,7 +723,7 @@ HTML_SANITIZATION = True
 #     }
 # }
 # Be careful when extending the default schema to avoid XSS attacks.
-HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {}
+HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {}

 # Chrome allows up to 6 open connections per domain at a time. When there are more
 # than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for
@@ -768,13 +755,13 @@ EXCEL_EXPORT = {"encoding": "utf-8"}
 # time grains in superset/db_engine_specs/base.py).
 # For example: to disable 1 second time grain:
 # TIME_GRAIN_DENYLIST = ['PT1S']
-TIME_GRAIN_DENYLIST: List[str] = []
+TIME_GRAIN_DENYLIST: list[str] = []

 # Additional time grains to be supported using similar definitions as in
 # superset/db_engine_specs/base.py.
 # For example: To add a new 2 second time grain:
 # TIME_GRAIN_ADDONS = {'PT2S': '2 second'}
-TIME_GRAIN_ADDONS: Dict[str, str] = {}
+TIME_GRAIN_ADDONS: dict[str, str] = {}

 # Implementation of additional time grains per engine.
 # The column to be truncated is denoted `{col}` in the expression.
@@ -784,7 +771,7 @@ TIME_GRAIN_ADDONS: Dict[str, str] = {}
 #         'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)'
 #     }
 # }
-TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
+TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {}

 # ---------------------------------------------------
 # List of viz_types not allowed in your environment
@@ -792,7 +779,7 @@ TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
 # VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap']
 # ---------------------------------------------------

-VIZ_TYPE_DENYLIST: List[str] = []
+VIZ_TYPE_DENYLIST: list[str] = []

 # --------------------------------------------------
 # Modules, datasources and middleware to be registered
@@ -802,8 +789,8 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
         ("superset.connectors.sqla.models", ["SqlaTable"]),
     ]
 )
-ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
-ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
+ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {}
+ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = []

 # 1) https://docs.python-guide.org/writing/logging/
 # 2) https://docs.python.org/2/library/logging.config.html
@@ -925,9 +912,9 @@ CELERY_CONFIG = CeleryConfig  # pylint: disable=invalid-name
 # within the app
 # OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will
 # override anything set within the app
-DEFAULT_HTTP_HEADERS: Dict[str, Any] = {}
-OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {}
-HTTP_HEADERS: Dict[str, Any] = {}
+DEFAULT_HTTP_HEADERS: dict[str, Any] = {}
+OVERRIDE_HTTP_HEADERS: dict[str, Any] = {}
+HTTP_HEADERS: dict[str, Any] = {}

 # The db id here results in selecting this one as a default in SQL Lab
 DEFAULT_DB_ID = None
@@ -974,8 +961,8 @@ SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
 #     return out
 #
 # QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter}
-QUERY_COST_FORMATTERS_BY_ENGINE: Dict[
-    str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
+QUERY_COST_FORMATTERS_BY_ENGINE: dict[
+    str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
 ] = {}

 # Flag that controls if limit should be enforced on the CTA (create table as queries).
@@ -1000,13 +987,13 @@ SQLLAB_CTAS_NO_LIMIT = False
 #     else:
 #         return f'tmp_{schema}'
 # Function accepts database object, user object, schema name and sql that will be run.
-SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
+SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | (
     Callable[[Database, models.User, str, str], str]
-] = None
+) = None

 # If enabled, it can be used to store the results of long-running queries
 # in SQL Lab by using the "Run Async" button/feature
-RESULTS_BACKEND: Optional[BaseCache] = None
+RESULTS_BACKEND: BaseCache | None = None

 # Use PyArrow and MessagePack for async query results serialization,
 # rather than JSON. This feature requires additional testing from the
@@ -1028,7 +1015,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
 def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC(  # pylint: disable=invalid-name
     database: Database,
     user: models.User,  # pylint: disable=unused-argument
-    schema: Optional[str],
+    schema: str | None,
 ) -> str:
     # Note the final empty path enforces a trailing slash.
     return os.path.join(
@@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC(  # pylint: disable=invalid-name

 # The namespace within hive where the tables created from
 # uploading CSVs will be stored.
-UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
+UPLOADED_CSV_HIVE_NAMESPACE: str | None = None

 # Function that computes the allowed schemas for the CSV uploads.
 # Allowed schemas will be a union of schemas_allowed_for_file_upload
 # db configuration and a result of this function.

 # mypy doesn't catch that if case ensures list content being always str
-ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = (
|
||||
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = (
|
||||
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
|
||||
if UPLOADED_CSV_HIVE_NAMESPACE
|
||||
else []
|
||||
|
@ -1062,7 +1049,7 @@ CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
|
|||
# It's important to make sure that the objects exposed (as well as objects attached
|
||||
# to those objets) are harmless. We recommend only exposing simple/pure functions that
|
||||
# return native types.
|
||||
JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
|
||||
JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
# A dictionary of macro template processors (by engine) that gets merged into global
|
||||
# template processors. The existing template processors get updated with this
|
||||
|
@ -1070,7 +1057,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
|
|||
# dictionary. The customized addons don't necessarily need to use Jinja templating
|
||||
# language. This allows you to define custom logic to process templates on a per-engine
|
||||
# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}`
|
||||
CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
|
||||
CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {}
|
||||
|
||||
# Roles that are controlled by the API / Superset and should not be changes
|
||||
# by humans.
|
||||
|
@ -1125,7 +1112,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
|
|||
|
||||
# Integrate external Blueprints to the app by passing them to your
|
||||
# configuration. These blueprints will get integrated in the app
|
||||
BLUEPRINTS: List[Blueprint] = []
|
||||
BLUEPRINTS: list[Blueprint] = []
|
||||
|
||||
# Provide a callable that receives a tracking_url and returns another
|
||||
# URL. This is used to translate internal Hadoop job tracker URL
|
||||
|
@ -1142,7 +1129,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url
|
|||
|
||||
|
||||
# customize the polling time of each engine
|
||||
DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {}
|
||||
DB_POLL_INTERVAL_SECONDS: dict[str, int] = {}
|
||||
|
||||
# Interval between consecutive polls when using Presto Engine
|
||||
# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression
|
||||
|
@ -1159,7 +1146,7 @@ PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds())
|
|||
# "another_auth_method": auth_method,
|
||||
# },
|
||||
# }
|
||||
ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {}
|
||||
ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
|
||||
|
||||
# The id of a template dashboard that should be copied to every new user
|
||||
DASHBOARD_TEMPLATE_ID = None
|
||||
|
@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
|
|||
# Owners, filters for created_by, etc.
|
||||
# The users can also be excluded by overriding the get_exclude_users_from_lists method
|
||||
# in security manager
|
||||
EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None
|
||||
EXCLUDE_USERS_FROM_LISTS: list[str] | None = None
|
||||
|
||||
# For database connections, this dictionary will remove engines from the available
|
||||
# list/dropdown if you do not want these dbs to show as available.
|
||||
# The available list is generated by driver installed, and some engines have multiple
|
||||
# drivers.
|
||||
# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}}
|
||||
DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {}
|
||||
DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {}
|
||||
|
||||
# This auth provider is used by background (offline) tasks that need to access
|
||||
# protected resources. Can be overridden by end users in order to support
|
||||
|
@ -1261,7 +1248,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True
|
|||
# ExecutorType.OWNER,
|
||||
# ExecutorType.SELENIUM,
|
||||
# ]
|
||||
ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER]
|
||||
ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER]
|
||||
# if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout
|
||||
# Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG
|
||||
ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds())
|
||||
|
@ -1286,7 +1273,7 @@ EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
|
|||
EMAIL_REPORTS_CTA = "Explore in Superset"
|
||||
|
||||
# Slack API token for the superset reports, either string or callable
|
||||
SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None
|
||||
SLACK_API_TOKEN: Callable[[], str] | str | None = None
|
||||
SLACK_PROXY = None
|
||||
|
||||
# The webdriver to use for generating reports. Use one of the following
|
||||
|
@ -1310,7 +1297,7 @@ WEBDRIVER_WINDOW = {
|
|||
WEBDRIVER_AUTH_FUNC = None
|
||||
|
||||
# Any config options to be passed as-is to the webdriver
|
||||
WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"}
|
||||
WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"}
|
||||
|
||||
# Additional args to be passed as arguments to the config object
|
||||
# Note: If using Chrome, you'll want to add the "--marionette" arg.
|
||||
|
@ -1353,7 +1340,7 @@ SQL_VALIDATORS_BY_ENGINE = {
|
|||
# displayed prominently in the "Add Database" dialog. You should
|
||||
# use the "engine_name" attribute of the corresponding DB engine spec
|
||||
# in `superset/db_engine_specs/`.
|
||||
PREFERRED_DATABASES: List[str] = [
|
||||
PREFERRED_DATABASES: list[str] = [
|
||||
"PostgreSQL",
|
||||
"Presto",
|
||||
"MySQL",
|
||||
|
@ -1386,7 +1373,7 @@ TALISMAN_CONFIG = {
|
|||
#
|
||||
SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS?
|
||||
SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls?
|
||||
SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax"
|
||||
SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax"
|
||||
# Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection
|
||||
SESSION_PROTECTION = "strong"
|
||||
|
||||
|
@ -1418,7 +1405,7 @@ DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]
|
|||
# Path used to store SSL certificates that are generated when using custom certs.
|
||||
# Defaults to temporary directory.
|
||||
# Example: SSL_CERT_PATH = "/certs"
|
||||
SSL_CERT_PATH: Optional[str] = None
|
||||
SSL_CERT_PATH: str | None = None
|
||||
|
||||
# SQLA table mutator, every time we fetch the metadata for a certain table
|
||||
# (superset.connectors.sqla.models.SqlaTable), we call this hook
|
||||
|
@ -1443,9 +1430,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
|
|||
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
|
||||
Literal["None", "Lax", "Strict"]
|
||||
] = None
|
||||
) = None
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
|
||||
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
|
||||
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
|
||||
|
@ -1461,7 +1448,7 @@ GUEST_TOKEN_JWT_ALGO = "HS256"
|
|||
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
|
||||
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
|
||||
# Guest token audience for the embedded superset, either string or callable
|
||||
GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
|
||||
GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None
|
||||
|
||||
# A SQL dataset health check. Note if enabled it is strongly advised that the callable
|
||||
# be memoized to aid with performance, i.e.,
|
||||
|
@ -1492,7 +1479,7 @@ GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
|
|||
# cache_manager.cache.delete_memoized(func)
|
||||
# cache_manager.cache.set(name, code, timeout=0)
|
||||
#
|
||||
DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None
|
||||
DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None
|
||||
|
||||
# Do not show user info or profile in the menu
|
||||
MENU_HIDE_USER_INFO = False
|
||||
|
@ -1502,7 +1489,7 @@ MENU_HIDE_USER_INFO = False
|
|||
ENABLE_BROAD_ACTIVITY_ACCESS = True
|
||||
|
||||
# the advanced data type key should correspond to that set in the column metadata
|
||||
ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
|
||||
ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = {
|
||||
"internet_address": internet_address,
|
||||
"port": internet_port,
|
||||
}
|
||||
|
@ -1514,9 +1501,9 @@ ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
|
|||
# "Xyz",
|
||||
# [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}],
|
||||
# )
|
||||
WELCOME_PAGE_LAST_TAB: Union[
|
||||
Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]]
|
||||
] = "all"
|
||||
WELCOME_PAGE_LAST_TAB: (
|
||||
Literal["examples", "all"] | tuple[str, list[dict[str, Any]]]
|
||||
) = "all"
|
||||
|
||||
# Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag.
|
||||
# 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base)
|
||||
|
|
|
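The hunks above all follow the two mechanical rewrites that pyupgrade applies in `--py39-plus` mode: the `typing.Dict`/`List`/`Set`/`Tuple`/`Type` aliases become the builtin generics of PEP 585, and `Optional[X]`/`Union[X, Y]` become the PEP 604 unions `X | None`/`X | Y`. On Python 3.9 the `|` spelling is only safe inside annotations when `from __future__ import annotations` (PEP 563) defers their evaluation, which is why the rewritten modules carry that import (the next file's hunk shows it being added). A minimal sketch of the before/after shape — the names `OPTIONS` and `TOKEN` are illustrative, not taken from the diff:

from __future__ import annotations  # defers annotation evaluation (PEP 563)

from typing import Any

# Before the rewrite these read:
#     OPTIONS: Dict[str, Any] = {}
#     TOKEN: Optional[str] = None
OPTIONS: dict[str, Any] = {}  # PEP 585 builtin generic, native since Python 3.9
TOKEN: str | None = None  # PEP 604 union; legal on 3.9 only via the future import
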
@@ -16,21 +16,12 @@
 # under the License.
+from __future__ import annotations

+import builtins
 import json
+from collections.abc import Hashable
 from datetime import datetime
 from enum import Enum
-from typing import (
-    Any,
-    Dict,
-    Hashable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, TYPE_CHECKING

 from flask_appbuilder.security.sqla.models import User
 from flask_babel import gettext as __
@@ -89,23 +80,23 @@ class BaseDatasource(
     # ---------------------------------------------------------------
     # class attributes to define when deriving BaseDatasource
     # ---------------------------------------------------------------
-    __tablename__: Optional[str] = None  # {connector_name}_datasource
-    baselink: Optional[str] = None  # url portion pointing to ModelView endpoint
+    __tablename__: str | None = None  # {connector_name}_datasource
+    baselink: str | None = None  # url portion pointing to ModelView endpoint

     @property
-    def column_class(self) -> Type["BaseColumn"]:
+    def column_class(self) -> type[BaseColumn]:
         # link to derivative of BaseColumn
         raise NotImplementedError()

     @property
-    def metric_class(self) -> Type["BaseMetric"]:
+    def metric_class(self) -> type[BaseMetric]:
         # link to derivative of BaseMetric
         raise NotImplementedError()

-    owner_class: Optional[User] = None
+    owner_class: User | None = None

     # Used to do code highlighting when displaying the query in the UI
-    query_language: Optional[str] = None
+    query_language: str | None = None

     # Only some datasources support Row Level Security
     is_rls_supported: bool = False
@@ -131,9 +122,9 @@ class BaseDatasource(
     is_managed_externally = Column(Boolean, nullable=False, default=False)
     external_url = Column(Text, nullable=True)

-    sql: Optional[str] = None
-    owners: List[User]
-    update_from_object_fields: List[str]
+    sql: str | None = None
+    owners: list[User]
+    update_from_object_fields: list[str]

     extra_import_fields = ["is_managed_externally", "external_url"]

@@ -142,7 +133,7 @@ class BaseDatasource(
         return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL

     @property
-    def owners_data(self) -> List[Dict[str, Any]]:
+    def owners_data(self) -> list[dict[str, Any]]:
         return [
             {
                 "first_name": o.first_name,
@@ -167,8 +158,8 @@ class BaseDatasource(
             ),
         )

-    columns: List["BaseColumn"] = []
-    metrics: List["BaseMetric"] = []
+    columns: list[BaseColumn] = []
+    metrics: list[BaseMetric] = []

     @property
     def type(self) -> str:
@@ -180,11 +171,11 @@ class BaseDatasource(
         return f"{self.id}__{self.type}"

     @property
-    def column_names(self) -> List[str]:
+    def column_names(self) -> list[str]:
         return sorted([c.column_name for c in self.columns], key=lambda x: x or "")

     @property
-    def columns_types(self) -> Dict[str, str]:
+    def columns_types(self) -> dict[str, str]:
         return {c.column_name: c.type for c in self.columns}

     @property
@@ -196,26 +187,26 @@ class BaseDatasource(
         raise NotImplementedError()

     @property
-    def connection(self) -> Optional[str]:
+    def connection(self) -> str | None:
         """String representing the context of the Datasource"""
         return None

     @property
-    def schema(self) -> Optional[str]:
+    def schema(self) -> str | None:
         """String representing the schema of the Datasource (if it applies)"""
         return None

     @property
-    def filterable_column_names(self) -> List[str]:
+    def filterable_column_names(self) -> list[str]:
         return sorted([c.column_name for c in self.columns if c.filterable])

     @property
-    def dttm_cols(self) -> List[str]:
+    def dttm_cols(self) -> list[str]:
         return []

     @property
     def url(self) -> str:
-        return "/{}/edit/{}".format(self.baselink, self.id)
+        return f"/{self.baselink}/edit/{self.id}"

     @property
     def explore_url(self) -> str:
@@ -224,10 +215,10 @@ class BaseDatasource(
         return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"

     @property
-    def column_formats(self) -> Dict[str, Optional[str]]:
+    def column_formats(self) -> dict[str, str | None]:
         return {m.metric_name: m.d3format for m in self.metrics if m.d3format}

-    def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
+    def add_missing_metrics(self, metrics: list[BaseMetric]) -> None:
         existing_metrics = {m.metric_name for m in self.metrics}
         for metric in metrics:
             if metric.metric_name not in existing_metrics:
@@ -235,7 +226,7 @@ class BaseDatasource(
                 self.metrics.append(metric)

     @property
-    def short_data(self) -> Dict[str, Any]:
+    def short_data(self) -> dict[str, Any]:
         """Data representation of the datasource sent to the frontend"""
         return {
             "edit_url": self.url,
@@ -249,11 +240,11 @@ class BaseDatasource(
         }

     @property
-    def select_star(self) -> Optional[str]:
+    def select_star(self) -> str | None:
         pass

     @property
-    def order_by_choices(self) -> List[Tuple[str, str]]:
+    def order_by_choices(self) -> list[tuple[str, str]]:
         choices = []
         # self.column_names return sorted column_names
         for column_name in self.column_names:
@@ -267,7 +258,7 @@ class BaseDatasource(
         return choices

     @property
-    def verbose_map(self) -> Dict[str, str]:
+    def verbose_map(self) -> dict[str, str]:
         verb_map = {"__timestamp": "Time"}
         verb_map.update(
             {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
@@ -278,7 +269,7 @@ class BaseDatasource(
         return verb_map

     @property
-    def data(self) -> Dict[str, Any]:
+    def data(self) -> dict[str, Any]:
         """Data representation of the datasource sent to the frontend"""
         return {
             # simple fields
@@ -313,8 +304,8 @@ class BaseDatasource(
         }

     def data_for_slices(  # pylint: disable=too-many-locals
-        self, slices: List[Slice]
-    ) -> Dict[str, Any]:
+        self, slices: list[Slice]
+    ) -> dict[str, Any]:
         """
         The representation of the datasource containing only the required data
         to render the provided slices.
@@ -381,8 +372,8 @@ class BaseDatasource(
             if metric["metric_name"] in metric_names
         ]

-        filtered_columns: List[Column] = []
-        column_types: Set[GenericDataType] = set()
+        filtered_columns: list[Column] = []
+        column_types: set[GenericDataType] = set()
         for column in data["columns"]:
             generic_type = column.get("type_generic")
             if generic_type is not None:
@@ -413,18 +404,18 @@ class BaseDatasource(

     @staticmethod
     def filter_values_handler(  # pylint: disable=too-many-arguments
-        values: Optional[FilterValues],
+        values: FilterValues | None,
         operator: str,
         target_generic_type: GenericDataType,
-        target_native_type: Optional[str] = None,
+        target_native_type: str | None = None,
         is_list_target: bool = False,
-        db_engine_spec: Optional[Type[BaseEngineSpec]] = None,
-        db_extra: Optional[Dict[str, Any]] = None,
-    ) -> Optional[FilterValues]:
+        db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
+        db_extra: dict[str, Any] | None = None,
+    ) -> FilterValues | None:
         if values is None:
             return None

-        def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
+        def handle_single_value(value: FilterValue | None) -> FilterValue | None:
             if operator == utils.FilterOperator.TEMPORAL_RANGE:
                 return value
             if (
@@ -464,7 +455,7 @@ class BaseDatasource(
         values = values[0] if values else None
         return values

-    def external_metadata(self) -> List[Dict[str, str]]:
+    def external_metadata(self) -> list[dict[str, str]]:
         """Returns column information from the external system"""
         raise NotImplementedError()

@@ -483,7 +474,7 @@ class BaseDatasource(
         """
         raise NotImplementedError()

-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
         """Given a column, returns an iterable of distinct values

         This is used to populate the dropdown showing a list of
@@ -494,7 +485,7 @@ class BaseDatasource(
         def default_query(qry: Query) -> Query:
             return qry

-    def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+    def get_column(self, column_name: str | None) -> BaseColumn | None:
         if not column_name:
             return None
         for col in self.columns:
@@ -504,11 +495,11 @@ class BaseDatasource(

     @staticmethod
     def get_fk_many_from_list(
-        object_list: List[Any],
-        fkmany: List[Column],
-        fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
+        object_list: list[Any],
+        fkmany: list[Column],
+        fkmany_class: builtins.type[BaseColumn | BaseMetric],
         key_attr: str,
-    ) -> List[Column]:
+    ) -> list[Column]:
         """Update ORM one-to-many list from object list

         Used for syncing metrics and columns using the same code"""
@@ -541,7 +532,7 @@ class BaseDatasource(
         fkmany += new_fks
         return fkmany

-    def update_from_object(self, obj: Dict[str, Any]) -> None:
+    def update_from_object(self, obj: dict[str, Any]) -> None:
         """Update datasource from a data structure

         The UI's table editor crafts a complex data structure that
@@ -578,7 +569,7 @@ class BaseDatasource(

     def get_extra_cache_keys(  # pylint: disable=no-self-use
         self, query_obj: QueryObjectDict  # pylint: disable=unused-argument
-    ) -> List[Hashable]:
+    ) -> list[Hashable]:
         """If a datasource needs to provide additional keys for calculation of
         cache keys, those can be provided via this method

@@ -607,14 +598,14 @@ class BaseDatasource(
     @classmethod
     def get_datasource_by_name(
         cls, session: Session, datasource_name: str, schema: str, database_name: str
-    ) -> Optional["BaseDatasource"]:
+    ) -> BaseDatasource | None:
         raise NotImplementedError()


 class BaseColumn(AuditMixinNullable, ImportExportMixin):
     """Interface for column"""

-    __tablename__: Optional[str] = None  # {connector_name}_column
+    __tablename__: str | None = None  # {connector_name}_column

     id = Column(Integer, primary_key=True)
     column_name = Column(String(255), nullable=False)
@@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
     is_dttm = None

     # [optional] Set this to support import/export functionality
-    export_fields: List[Any] = []
+    export_fields: list[Any] = []

     def __repr__(self) -> str:
         return str(self.column_name)
@@ -666,7 +657,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
         return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types))

     @property
-    def type_generic(self) -> Optional[utils.GenericDataType]:
+    def type_generic(self) -> utils.GenericDataType | None:
         if self.is_string:
             return utils.GenericDataType.STRING
         if self.is_boolean:
@@ -686,7 +677,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
         raise NotImplementedError()

     @property
-    def data(self) -> Dict[str, Any]:
+    def data(self) -> dict[str, Any]:
         attrs = (
             "id",
             "column_name",
@@ -705,7 +696,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
 class BaseMetric(AuditMixinNullable, ImportExportMixin):
     """Interface for Metrics"""

-    __tablename__: Optional[str] = None  # {connector_name}_metric
+    __tablename__: str | None = None  # {connector_name}_metric

     id = Column(Integer, primary_key=True)
     metric_name = Column(String(255), nullable=False)
@@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
     """

     @property
-    def perm(self) -> Optional[str]:
+    def perm(self) -> str | None:
         raise NotImplementedError()

     @property
@@ -738,7 +729,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
         raise NotImplementedError()

     @property
-    def data(self) -> Dict[str, Any]:
+    def data(self) -> dict[str, Any]:
         attrs = (
             "id",
             "metric_name",

@@ -22,21 +22,10 @@ import json
 import logging
 import re
 from collections import defaultdict
+from collections.abc import Hashable
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Hashable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Any, Callable, cast

 import dateutil.parser
 import numpy as np
@@ -136,9 +125,9 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}

 @dataclass
 class MetadataResult:
-    added: List[str] = field(default_factory=list)
-    removed: List[str] = field(default_factory=list)
-    modified: List[str] = field(default_factory=list)
+    added: list[str] = field(default_factory=list)
+    removed: list[str] = field(default_factory=list)
+    modified: list[str] = field(default_factory=list)


 class AnnotationDatasource(BaseDatasource):
@@ -190,7 +179,7 @@ class AnnotationDatasource(BaseDatasource):
     def get_query_str(self, query_obj: QueryObjectDict) -> str:
         raise NotImplementedError()

-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
         raise NotImplementedError()


@@ -201,7 +190,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
     __tablename__ = "table_columns"
     __table_args__ = (UniqueConstraint("table_id", "column_name"),)
     table_id = Column(Integer, ForeignKey("tables.id"))
-    table: Mapped["SqlaTable"] = relationship(
+    table: Mapped[SqlaTable] = relationship(
         "SqlaTable",
         back_populates="columns",
     )
@@ -263,15 +252,15 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         return self.type_generic == GenericDataType.TEMPORAL

     @property
-    def db_engine_spec(self) -> Type[BaseEngineSpec]:
+    def db_engine_spec(self) -> type[BaseEngineSpec]:
         return self.table.db_engine_spec

     @property
-    def db_extra(self) -> Dict[str, Any]:
+    def db_extra(self) -> dict[str, Any]:
         return self.table.database.get_extra()

     @property
-    def type_generic(self) -> Optional[utils.GenericDataType]:
+    def type_generic(self) -> utils.GenericDataType | None:
         if self.is_dttm:
             return GenericDataType.TEMPORAL

@@ -310,8 +299,8 @@ class TableColumn(Model, BaseColumn, CertificationMixin):

     def get_sqla_col(
         self,
-        label: Optional[str] = None,
-        template_processor: Optional[BaseTemplateProcessor] = None,
+        label: str | None = None,
+        template_processor: BaseTemplateProcessor | None = None,
     ) -> Column:
         label = label or self.column_name
         db_engine_spec = self.db_engine_spec
@@ -332,10 +321,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):

     def get_timestamp_expression(
         self,
-        time_grain: Optional[str],
-        label: Optional[str] = None,
-        template_processor: Optional[BaseTemplateProcessor] = None,
-    ) -> Union[TimestampExpression, Label]:
+        time_grain: str | None,
+        label: str | None = None,
+        template_processor: BaseTemplateProcessor | None = None,
+    ) -> TimestampExpression | Label:
         """
         Return a SQLAlchemy Core element representation of self to be used in a query.

@@ -365,7 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
         return self.table.make_sqla_column_compatible(time_expr, label)

     @property
-    def data(self) -> Dict[str, Any]:
+    def data(self) -> dict[str, Any]:
         attrs = (
             "id",
             "column_name",
@@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
     __tablename__ = "sql_metrics"
     __table_args__ = (UniqueConstraint("table_id", "metric_name"),)
     table_id = Column(Integer, ForeignKey("tables.id"))
-    table: Mapped["SqlaTable"] = relationship(
+    table: Mapped[SqlaTable] = relationship(
         "SqlaTable",
         back_populates="metrics",
     )
@@ -425,8 +414,8 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):

     def get_sqla_col(
         self,
-        label: Optional[str] = None,
-        template_processor: Optional[BaseTemplateProcessor] = None,
+        label: str | None = None,
+        template_processor: BaseTemplateProcessor | None = None,
     ) -> Column:
         label = label or self.metric_name
         expression = self.expression
@@ -437,7 +426,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
         return self.table.make_sqla_column_compatible(sqla_col, label)

     @property
-    def perm(self) -> Optional[str]:
+    def perm(self) -> str | None:
         return (
             ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
                 obj=self, parent_name=self.table.full_name
@@ -446,11 +435,11 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
             else None
         )

-    def get_perm(self) -> Optional[str]:
+    def get_perm(self) -> str | None:
         return self.perm

     @property
-    def data(self) -> Dict[str, Any]:
+    def data(self) -> dict[str, Any]:
         attrs = (
             "is_certified",
             "certified_by",
@@ -473,11 +462,11 @@ sqlatable_user = Table(


 def _process_sql_expression(
-    expression: Optional[str],
+    expression: str | None,
     database_id: int,
     schema: str,
-    template_processor: Optional[BaseTemplateProcessor] = None,
-) -> Optional[str]:
+    template_processor: BaseTemplateProcessor | None = None,
+) -> str | None:
     if template_processor and expression:
         expression = template_processor.process_template(expression)
     if expression:
@@ -501,12 +490,12 @@ class SqlaTable(
     type = "table"
     query_language = "sql"
     is_rls_supported = True
-    columns: Mapped[List[TableColumn]] = relationship(
+    columns: Mapped[list[TableColumn]] = relationship(
         TableColumn,
         back_populates="table",
         cascade="all, delete-orphan",
     )
-    metrics: Mapped[List[SqlMetric]] = relationship(
+    metrics: Mapped[list[SqlMetric]] = relationship(
         SqlMetric,
         back_populates="table",
         cascade="all, delete-orphan",
@@ -577,11 +566,11 @@ class SqlaTable(
         return self.name

     @property
-    def db_extra(self) -> Dict[str, Any]:
+    def db_extra(self) -> dict[str, Any]:
         return self.database.get_extra()

     @staticmethod
-    def _apply_cte(sql: str, cte: Optional[str]) -> str:
+    def _apply_cte(sql: str, cte: str | None) -> str:
         """
         Append a CTE before the SELECT statement if defined

@@ -594,7 +583,7 @@ class SqlaTable(
         return sql

     @property
-    def db_engine_spec(self) -> Type[BaseEngineSpec]:
+    def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
         return self.database.db_engine_spec

     @property
@@ -637,9 +626,9 @@ class SqlaTable(
         cls,
         session: Session,
         datasource_name: str,
-        schema: Optional[str],
+        schema: str | None,
         database_name: str,
-    ) -> Optional[SqlaTable]:
+    ) -> SqlaTable | None:
         schema = schema or None
         query = (
             session.query(cls)
@@ -660,7 +649,7 @@ class SqlaTable(
         anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
         return Markup(anchor)

-    def get_schema_perm(self) -> Optional[str]:
+    def get_schema_perm(self) -> str | None:
         """Returns schema permission if present, database one otherwise."""
         return security_manager.get_schema_perm(self.database, self.schema)

@@ -685,18 +674,18 @@ class SqlaTable(
         )

     @property
-    def dttm_cols(self) -> List[str]:
+    def dttm_cols(self) -> list[str]:
         l = [c.column_name for c in self.columns if c.is_dttm]
         if self.main_dttm_col and self.main_dttm_col not in l:
             l.append(self.main_dttm_col)
         return l

     @property
-    def num_cols(self) -> List[str]:
+    def num_cols(self) -> list[str]:
         return [c.column_name for c in self.columns if c.is_numeric]

     @property
-    def any_dttm_col(self) -> Optional[str]:
+    def any_dttm_col(self) -> str | None:
         cols = self.dttm_cols
         return cols[0] if cols else None

@@ -713,7 +702,7 @@ class SqlaTable(
     def sql_url(self) -> str:
         return self.database.sql_url + "?table_name=" + str(self.table_name)

-    def external_metadata(self) -> List[Dict[str, str]]:
+    def external_metadata(self) -> list[dict[str, str]]:
         # todo(yongjie): create a physical table column type in a separate PR
         if self.sql:
             return get_virtual_table_metadata(dataset=self)  # type: ignore
@@ -724,14 +713,14 @@ class SqlaTable(
         )

     @property
-    def time_column_grains(self) -> Dict[str, Any]:
+    def time_column_grains(self) -> dict[str, Any]:
         return {
             "time_columns": self.dttm_cols,
             "time_grains": [grain.name for grain in self.database.grains()],
         }

     @property
-    def select_star(self) -> Optional[str]:
+    def select_star(self) -> str | None:
         # show_cols and latest_partition set to false to avoid
         # the expensive cost of inspecting the DB
         return self.database.select_star(
@@ -739,20 +728,20 @@ class SqlaTable(
         )

     @property
-    def health_check_message(self) -> Optional[str]:
+    def health_check_message(self) -> str | None:
         check = config["DATASET_HEALTH_CHECK"]
         return check(self) if check else None

     @property
-    def granularity_sqla(self) -> List[Tuple[Any, Any]]:
+    def granularity_sqla(self) -> list[tuple[Any, Any]]:
         return utils.choicify(self.dttm_cols)

     @property
-    def time_grain_sqla(self) -> List[Tuple[Any, Any]]:
+    def time_grain_sqla(self) -> list[tuple[Any, Any]]:
         return [(g.duration, g.name) for g in self.database.grains() or []]

     @property
-    def data(self) -> Dict[str, Any]:
+    def data(self) -> dict[str, Any]:
         data_ = super().data
         if self.type == "table":
             data_["granularity_sqla"] = self.granularity_sqla
@@ -767,7 +756,7 @@ class SqlaTable(
         return data_

     @property
-    def extra_dict(self) -> Dict[str, Any]:
+    def extra_dict(self) -> dict[str, Any]:
         try:
             return json.loads(self.extra)
         except (TypeError, json.JSONDecodeError):
@@ -775,7 +764,7 @@ class SqlaTable(

     def get_fetch_values_predicate(
         self,
-        template_processor: Optional[BaseTemplateProcessor] = None,
+        template_processor: BaseTemplateProcessor | None = None,
     ) -> TextClause:
         fetch_values_predicate = self.fetch_values_predicate
         if template_processor:
@@ -792,7 +781,7 @@ class SqlaTable(
             )
         ) from ex

-    def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+    def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
         """Runs query against sqla to retrieve some
         sample values for the given column.
         """
@@ -869,8 +858,8 @@ class SqlaTable(
         return tbl

     def get_from_clause(
-        self, template_processor: Optional[BaseTemplateProcessor] = None
-    ) -> Tuple[Union[TableClause, Alias], Optional[str]]:
+        self, template_processor: BaseTemplateProcessor | None = None
+    ) -> tuple[TableClause | Alias, str | None]:
         """
         Return where to select the columns and metrics from. Either a physical table
         or a virtual table with it's own subquery. If the FROM is referencing a
@@ -899,7 +888,7 @@ class SqlaTable(
         return from_clause, cte

     def get_rendered_sql(
-        self, template_processor: Optional[BaseTemplateProcessor] = None
+        self, template_processor: BaseTemplateProcessor | None = None
     ) -> str:
         """
         Render sql with template engine (Jinja).
@@ -928,8 +917,8 @@ class SqlaTable(
     def adhoc_metric_to_sqla(
         self,
         metric: AdhocMetric,
-        columns_by_name: Dict[str, TableColumn],
-        template_processor: Optional[BaseTemplateProcessor] = None,
+        columns_by_name: dict[str, TableColumn],
+        template_processor: BaseTemplateProcessor | None = None,
     ) -> ColumnElement:
         """
         Turn an adhoc metric into a sqlalchemy column.
@@ -946,7 +935,7 @@ class SqlaTable(
         if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
             metric_column = metric.get("column") or {}
             column_name = cast(str, metric_column.get("column_name"))
-            table_column: Optional[TableColumn] = columns_by_name.get(column_name)
+            table_column: TableColumn | None = columns_by_name.get(column_name)
             if table_column:
                 sqla_column = table_column.get_sqla_col(
                     template_processor=template_processor
@@ -971,7 +960,7 @@ class SqlaTable(
         self,
         col: AdhocColumn,
         force_type_check: bool = False,
-        template_processor: Optional[BaseTemplateProcessor] = None,
+        template_processor: BaseTemplateProcessor | None = None,
     ) -> ColumnElement:
         """
         Turn an adhoc column into a sqlalchemy column.
@@ -1021,7 +1010,7 @@ class SqlaTable(
         return self.make_sqla_column_compatible(sqla_column, label)

     def make_sqla_column_compatible(
-        self, sqla_col: ColumnElement, label: Optional[str] = None
+        self, sqla_col: ColumnElement, label: str | None = None
     ) -> ColumnElement:
         """Takes a sqlalchemy column object and adds label info if supported by engine.
         :param sqla_col: sqlalchemy column instance
@@ -1038,7 +1027,7 @@ class SqlaTable(
         return sqla_col

     def make_orderby_compatible(
-        self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
+        self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
     ) -> None:
         """
         If needed, make sure aliases for selected columns are not used in
@@ -1069,7 +1058,7 @@ class SqlaTable(
     def get_sqla_row_level_filters(
         self,
         template_processor: BaseTemplateProcessor,
-    ) -> List[TextClause]:
+    ) -> list[TextClause]:
         """
         Return the appropriate row level security filters for this table and the
         current user. A custom username can be passed when the user is not present in the
@@ -1078,8 +1067,8 @@ class SqlaTable(
         :param template_processor: The template processor to apply to the filters.
         :returns: A list of SQL clauses to be ANDed together.
         """
-        all_filters: List[TextClause] = []
-        filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
+        all_filters: list[TextClause] = []
+        filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
         try:
             for filter_ in security_manager.get_rls_filters(self):
                 clause = self.text(
@@ -1114,9 +1103,9 @@ class SqlaTable(
     def _get_series_orderby(
         self,
         series_limit_metric: Metric,
-        metrics_by_name: Dict[str, SqlMetric],
-        columns_by_name: Dict[str, TableColumn],
-        template_processor: Optional[BaseTemplateProcessor] = None,
+        metrics_by_name: dict[str, SqlMetric],
+        columns_by_name: dict[str, TableColumn],
+        template_processor: BaseTemplateProcessor | None = None,
     ) -> Column:
         if utils.is_adhoc_metric(series_limit_metric):
             assert isinstance(series_limit_metric, dict)
@@ -1138,8 +1127,8 @@ class SqlaTable(
         self,
         row: pd.Series,
         dimension: str,
-        columns_by_name: Dict[str, TableColumn],
-    ) -> Union[str, int, float, bool, Text]:
+        columns_by_name: dict[str, TableColumn],
+    ) -> str | int | float | bool | Text:
         """
         Convert a prequery result type to its equivalent Python type.

@@ -1159,7 +1148,7 @@ class SqlaTable(
             value = value.item()

         column_ = columns_by_name[dimension]
-        db_extra: Dict[str, Any] = self.database.get_extra()
+        db_extra: dict[str, Any] = self.database.get_extra()

         if column_.type and column_.is_temporal and isinstance(value, str):
             sql = self.db_engine_spec.convert_dttm(
@@ -1174,9 +1163,9 @@ class SqlaTable(
     def _get_top_groups(
         self,
         df: pd.DataFrame,
-        dimensions: List[str],
-        groupby_exprs: Dict[str, Any],
-        columns_by_name: Dict[str, TableColumn],
+        dimensions: list[str],
+        groupby_exprs: dict[str, Any],
+        columns_by_name: dict[str, TableColumn],
     ) -> ColumnElement:
         groups = []
         for _unused, row in df.iterrows():
@@ -1201,7 +1190,7 @@ class SqlaTable(
         errors = None
         error_message = None

-        def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
+        def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
            """
            Some engines change the case or generate bespoke column names, either by
            default or due to lack of support for aliasing. This function ensures that
@@ -1283,7 +1272,7 @@ class SqlaTable(
             else self.columns
         )

-        old_columns_by_name: Dict[str, TableColumn] = {
+        old_columns_by_name: dict[str, TableColumn] = {
             col.column_name: col for col in old_columns
         }
         results = MetadataResult(
@@ -1341,8 +1330,8 @@ class SqlaTable(
         session: Session,
         database: Database,
         datasource_name: str,
-        schema: Optional[str] = None,
-    ) -> List[SqlaTable]:
+        schema: str | None = None,
+    ) -> list[SqlaTable]:
         query = (
             session.query(cls)
             .filter_by(database_id=database.id)
@@ -1357,9 +1346,9 @@ class SqlaTable(
         cls,
         session: Session,
         database: Database,
-        permissions: Set[str],
-        schema_perms: Set[str],
-    ) -> List[SqlaTable]:
+        permissions: set[str],
+        schema_perms: set[str],
+    ) -> list[SqlaTable]:
         # TODO(hughhhh): add unit test
         return (
             session.query(cls)
@@ -1389,7 +1378,7 @@ class SqlaTable(
         )

     @classmethod
-    def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
+    def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
         qry = session.query(cls)
         qry = cls.default_query(qry)
         return qry.all()
@@ -1409,7 +1398,7 @@ class SqlaTable(
         :param query_obj: query object to analyze
         :return: True if there are call(s) to an `ExtraCache` method, False otherwise
         """
-        templatable_statements: List[str] = []
+        templatable_statements: list[str] = []
         if self.sql:
             templatable_statements.append(self.sql)
         if self.fetch_values_predicate:
@@ -1428,7 +1417,7 @@ class SqlaTable(
                 return True
         return False

-    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
+    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
         """
         The cache key of a SqlaTable needs to consider any keys added by the parent
         class and any keys added via `ExtraCache`.
@@ -1489,7 +1478,7 @@ class SqlaTable(

     @staticmethod
     def update_column(  # pylint: disable=unused-argument
-        mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
+        mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn
     ) -> None:
         """
         :param mapper: Unused.

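One wrinkle in the two model files above is worth spelling out. `Type[X]` normally becomes plain `type[X]`, but `BaseDatasource` defines its own `type` property, so inside that class body the bare name resolves to the property rather than the builtin; the rewrite therefore spells the builtin explicitly as `builtins.type[...]` there. (The `__builtins__.type` variant in `SqlaTable.db_engine_spec` leans on a CPython implementation detail and looks like an artifact of the automated rewrite rather than a deliberate choice.) A short sketch of the shadowing problem, using an invented `Widget` class rather than anything from the diff:

from __future__ import annotations

import builtins


class Widget:
    @property
    def type(self) -> str:  # shadows the builtin name inside the class body
        return "widget"

    # Bare `type[int]` here would be resolved against the property above
    # (by type checkers, and at runtime without the future import, where it
    # raises "TypeError: 'property' object is not subscriptable"), so the
    # builtin is named explicitly:
    def factory(self) -> builtins.type[int]:
        return int
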
@@ -17,19 +17,9 @@
 from __future__ import annotations

 import logging
+from collections.abc import Iterable, Iterator
 from functools import lru_cache
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Type,
-    TYPE_CHECKING,
-    TypeVar,
-)
+from typing import Any, Callable, TYPE_CHECKING, TypeVar
 from uuid import UUID

 from flask_babel import lazy_gettext as _
@@ -58,8 +48,8 @@ if TYPE_CHECKING:
 def get_physical_table_metadata(
     database: Database,
     table_name: str,
-    schema_name: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+    schema_name: str | None = None,
+) -> list[dict[str, Any]]:
     """Use SQLAlchemy inspector to get table metadata"""
     db_engine_spec = database.db_engine_spec
     db_dialect = database.get_dialect()
@@ -103,7 +93,7 @@ def get_physical_table_metadata(
     return cols


-def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
+def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
     """Use SQLparser to get virtual dataset metadata"""
     if not dataset.sql:
         raise SupersetGenericDBErrorException(
@@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
 def get_columns_description(
     database: Database,
     query: str,
-) -> List[ResultSetColumnType]:
+) -> list[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
         with database.get_raw_connection() as conn:
@@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str:


 @lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
-def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
+def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]:
     return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote


@@ -181,9 +171,9 @@ logger = logging.getLogger(__name__)

 def find_cached_objects_in_session(
     session: Session,
-    cls: Type[DeclarativeModel],
-    ids: Optional[Iterable[int]] = None,
-    uuids: Optional[Iterable[UUID]] = None,
+    cls: type[DeclarativeModel],
+    ids: Iterable[int] | None = None,
+    uuids: Iterable[UUID] | None = None,
 ) -> Iterator[DeclarativeModel]:
     """Find known ORM instances in cached SQLA session states.

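The utils hunk above shows the companion import rewrite: the typing aliases for container ABCs (`Iterable`, `Iterator`, `Hashable`, and friends) have been deprecated since Python 3.9 in favour of the directly subscriptable classes in `collections.abc` (PEP 585), so the import block is collapsed and repointed. A small illustrative helper, not part of the commit:

from collections.abc import Iterable, Iterator


def batched(items: Iterable[int], size: int) -> Iterator[list[int]]:
    # Yield fixed-size chunks from any iterable; the annotations subscript
    # collections.abc classes directly, which works natively on Python 3.9+.
    batch: list[int] = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch
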
@@ -447,7 +447,7 @@ class TableModelView(  # pylint: disable=too-many-ancestors
         resp = super().edit(pk)
         if isinstance(resp, str):
             return resp
-        return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk))
+        return redirect(f"/explore/?datasource_type=table&datasource_id={pk}")

     @expose("/list/")
     @has_access

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from superset.commands.base import BaseCommand
 from superset.css_templates.commands.exceptions import (
@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)


 class BulkDeleteCssTemplateCommand(BaseCommand):
-    def __init__(self, model_ids: List[int]):
+    def __init__(self, model_ids: list[int]):
         self._model_ids = model_ids
-        self._models: Optional[List[CssTemplate]] = None
+        self._models: Optional[list[CssTemplate]] = None

     def run(self) -> None:
         self.validate()

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from sqlalchemy.exc import SQLAlchemyError

@@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO):
     model_cls = CssTemplate

     @staticmethod
-    def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None:
+    def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
         item_ids = [model.id for model in models] if models else []
         try:
             db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete(

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=isinstance-second-argument-not-valid-type
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Optional, Union

 from flask_appbuilder.models.filters import BaseFilter
 from flask_appbuilder.models.sqla import Model
@@ -37,7 +37,7 @@ class BaseDAO:
     Base DAO, implement base CRUD sqlalchemy operations
     """

-    model_cls: Optional[Type[Model]] = None
+    model_cls: Optional[type[Model]] = None
     """
     Child classes need to state the Model class so they don't need to implement basic
     create, update and delete methods
@@ -75,10 +75,10 @@ class BaseDAO:
     @classmethod
     def find_by_ids(
         cls,
-        model_ids: Union[List[str], List[int]],
+        model_ids: Union[list[str], list[int]],
         session: Session = None,
         skip_base_filter: bool = False,
-    ) -> List[Model]:
+    ) -> list[Model]:
         """
         Find a List of models by a list of ids, if defined applies `base_filter`
         """
@@ -95,7 +95,7 @@ class BaseDAO:
         return query.all()

     @classmethod
-    def find_all(cls) -> List[Model]:
+    def find_all(cls) -> list[Model]:
         """
         Get all that fit the `base_filter`
         """
@@ -121,7 +121,7 @@ class BaseDAO:
         return query.filter_by(**filter_by).one_or_none()

     @classmethod
-    def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
+    def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
         """
         Generic for creating models
         :raises: DAOCreateFailedError
@@ -163,7 +163,7 @@ class BaseDAO:

     @classmethod
     def update(
-        cls, model: Model, properties: Dict[str, Any], commit: bool = True
+        cls, model: Model, properties: dict[str, Any], commit: bool = True
     ) -> Model:
         """
         Generic update a model
@@ -196,7 +196,7 @@ class BaseDAO:
         return model

     @classmethod
-    def bulk_delete(cls, models: List[Model], commit: bool = True) -> None:
+    def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
         try:
             for model in models:
                 cls.delete(model, False)

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import List, Optional
+from typing import Optional

 from flask_babel import lazy_gettext as _

@@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)


 class BulkDeleteDashboardCommand(BaseCommand):
-    def __init__(self, model_ids: List[int]):
+    def __init__(self, model_ids: list[int]):
         self._model_ids = model_ids
-        self._models: Optional[List[Dashboard]] = None
+        self._models: Optional[list[Dashboard]] = None

     def run(self) -> None:
         self.validate()

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)


 class CreateDashboardCommand(CreateMixin, BaseCommand):
-    def __init__(self, data: Dict[str, Any]):
+    def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()

     def run(self) -> Model:
@@ -48,9 +48,9 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
         return dashboard

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
-        owner_ids: Optional[List[int]] = self._properties.get("owners")
-        role_ids: Optional[List[int]] = self._properties.get("roles")
+        exceptions: list[ValidationError] = []
+        owner_ids: Optional[list[int]] = self._properties.get("owners")
+        role_ids: Optional[list[int]] = self._properties.get("roles")
         slug: str = self._properties.get("slug", "")

         # Validate slug uniqueness

@@ -20,7 +20,8 @@ import json
 import logging
 import random
 import string
-from typing import Any, Dict, Iterator, Optional, Set, Tuple
+from collections.abc import Iterator
+from typing import Any, Optional

 import yaml

@@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str:
     )


-def get_default_position(title: str) -> Dict[str, Any]:
+def get_default_position(title: str) -> dict[str, Any]:
     return {
         "DASHBOARD_VERSION_KEY": "v2",
         "ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},

@@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]:
     }


-def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]:
+def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]:
     chart_hashes = [f"CHART-{suffix()}" for _ in charts]

     # if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid

@@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand):
     @staticmethod
     def _export(
         model: Dashboard, export_related: bool = True
-    ) -> Iterator[Tuple[str, str]]:
+    ) -> Iterator[tuple[str, str]]:
         file_name = get_filename(model.dashboard_title, model.id)
         file_path = f"dashboards/{file_name}.yaml"

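The `Iterator` change above is the companion modernization for the ABCs: `typing.Iterator` has been a deprecated alias of `collections.abc.Iterator` since Python 3.9, so pyupgrade moves the import. A small runnable sketch with the same generator shape as `_export` (names and payload are illustrative, not from this commit):

    from collections.abc import Iterator

    def export_files(names: list[str]) -> Iterator[tuple[str, str]]:
        # Yield (file_path, file_contents) pairs, like _export above
        for name in names:
            yield f"dashboards/{name}.yaml", f"dashboard_title: {name}\n"

    for path, contents in export_files(["sales", "ops"]):
        print(path, len(contents))
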
@@ -16,7 +16,7 @@
 # under the License.

 import logging
-from typing import Any, Dict
+from typing import Any

 from marshmallow.exceptions import ValidationError

@@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand):
     until it finds one that matches.
     """

-    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+    def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
         self.contents = contents
         self.args = args
         self.kwargs = kwargs

@@ -19,7 +19,7 @@ import logging
 import time
 from copy import copy
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from flask_babel import lazy_gettext as _
 from sqlalchemy.orm import make_transient, Session

@@ -83,7 +83,7 @@ def import_chart(
 def import_dashboard(
     # pylint: disable=too-many-locals,too-many-statements
     dashboard_to_import: Dashboard,
-    dataset_id_mapping: Optional[Dict[int, int]] = None,
+    dataset_id_mapping: Optional[dict[int, int]] = None,
     import_time: Optional[int] = None,
 ) -> int:
     """Imports the dashboard from the object to the database.

@@ -97,7 +97,7 @@ def import_dashboard(
     """

     def alter_positions(
-        dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
+        dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int]
     ) -> None:
         """Updates slice_ids in the position json.

@@ -166,7 +166,7 @@ def import_dashboard(
     dashboard_to_import.slug = None

     old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
-    old_to_new_slc_id_dict: Dict[int, int] = {}
+    old_to_new_slc_id_dict: dict[int, int] = {}
     new_timed_refresh_immune_slices = []
     new_expanded_slices = {}
     new_filter_scopes = {}

@@ -268,7 +268,7 @@ def import_dashboard(
     return dashboard_to_import.id  # type: ignore


-def decode_dashboards(o: Dict[str, Any]) -> Any:
+def decode_dashboards(o: dict[str, Any]) -> Any:
     """
     Function to be passed into json.loads obj_hook parameter
     Recreates the dashboard object from a json representation.

@@ -302,7 +302,7 @@ def import_dashboards(
     data = json.loads(content, object_hook=decode_dashboards)
     if not data:
         raise DashboardImportException(_("No data in file"))
-    dataset_id_mapping: Dict[int, int] = {}
+    dataset_id_mapping: dict[int, int] = {}
     for table in data["datasources"]:
         new_dataset_id = import_dataset(table, database_id, import_time=import_time)
         params = json.loads(table.params)

@@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand):

     # pylint: disable=unused-argument
     def __init__(
-        self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any
+        self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any
     ):
         self.contents = contents
         self.database_id = database_id

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-from typing import Any, Dict, List, Set, Tuple
+from typing import Any

 from marshmallow import Schema
 from sqlalchemy.orm import Session

@@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
     dao = DashboardDAO
     model_name = "dashboard"
     prefix = "dashboards/"
-    schemas: Dict[str, Schema] = {
+    schemas: dict[str, Schema] = {
         "charts/": ImportV1ChartSchema(),
         "dashboards/": ImportV1DashboardSchema(),
         "datasets/": ImportV1DatasetSchema(),

@@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand):
     # pylint: disable=too-many-branches, too-many-locals
     @staticmethod
     def _import(
-        session: Session, configs: Dict[str, Any], overwrite: bool = False
+        session: Session, configs: dict[str, Any], overwrite: bool = False
     ) -> None:
         # discover charts and datasets associated with dashboards
-        chart_uuids: Set[str] = set()
-        dataset_uuids: Set[str] = set()
+        chart_uuids: set[str] = set()
+        dataset_uuids: set[str] = set()
         for file_name, config in configs.items():
             if file_name.startswith("dashboards/"):
                 chart_uuids.update(find_chart_uuids(config["position"]))

@@ -77,20 +77,20 @@ class ImportDashboardsCommand(ImportModelsCommand):
                 dataset_uuids.add(config["dataset_uuid"])

         # discover databases associated with datasets
-        database_uuids: Set[str] = set()
+        database_uuids: set[str] = set()
         for file_name, config in configs.items():
             if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
                 database_uuids.add(config["database_uuid"])

         # import related databases
-        database_ids: Dict[str, int] = {}
+        database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/") and config["uuid"] in database_uuids:
                 database = import_database(session, config, overwrite=False)
                 database_ids[str(database.uuid)] = database.id

         # import datasets with the correct parent ref
-        dataset_info: Dict[str, Dict[str, Any]] = {}
+        dataset_info: dict[str, dict[str, Any]] = {}
         for file_name, config in configs.items():
             if (
                 file_name.startswith("datasets/")

@@ -105,7 +105,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
             }

         # import charts with the correct parent ref
-        chart_ids: Dict[str, int] = {}
+        chart_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if (
                 file_name.startswith("charts/")

@@ -129,7 +129,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
         ).fetchall()

         # import dashboards
-        dashboard_chart_ids: List[Tuple[int, int]] = []
+        dashboard_chart_ids: list[tuple[int, int]] = []
         for file_name, config in configs.items():
             if file_name.startswith("dashboards/"):
                 config = update_id_refs(config, chart_ids, dataset_info)

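Worth noting: these annotations are evaluated at runtime (the touched modules do not use `from __future__ import annotations`), which is presumably why the commit targets `--py39-plus`; subscripting the builtins only works from Python 3.9 on. A quick runnable check:

    # On Python 3.9+, subscripting a builtin produces a types.GenericAlias,
    # so annotations like these are safe to evaluate at import time.
    chart_ids: dict[str, int] = {"uuid-1": 1}
    dashboard_chart_ids: list[tuple[int, int]] = [(1, 10), (1, 11)]
    print(type(dict[str, int]))  # <class 'types.GenericAlias'>
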
@@ -17,7 +17,7 @@

 import json
 import logging
-from typing import Any, Dict, Set
+from typing import Any

 from flask import g
 from sqlalchemy.orm import Session

@@ -32,12 +32,12 @@ logger = logging.getLogger(__name__)
 JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"}


-def find_chart_uuids(position: Dict[str, Any]) -> Set[str]:
+def find_chart_uuids(position: dict[str, Any]) -> set[str]:
     return set(build_uuid_to_id_map(position))


-def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
-    uuids: Set[str] = set()
+def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]:
+    uuids: set[str] = set()
     for native_filter in metadata.get("native_filter_configuration", []):
         targets = native_filter.get("targets", [])
         for target in targets:

@@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
     return uuids


-def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
+def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]:
     return {
         child["meta"]["uuid"]: child["meta"]["chartId"]
         for child in position.values()

@@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:


 def update_id_refs(  # pylint: disable=too-many-locals
-    config: Dict[str, Any],
-    chart_ids: Dict[str, int],
-    dataset_info: Dict[str, Dict[str, Any]],
-) -> Dict[str, Any]:
+    config: dict[str, Any],
+    chart_ids: dict[str, int],
+    dataset_info: dict[str, dict[str, Any]],
+) -> dict[str, Any]:
     """Update dashboard metadata to use new IDs"""
     fixed = config.copy()

@@ -147,7 +147,7 @@ def update_id_refs(  # pylint: disable=too-many-locals

 def import_dashboard(
     session: Session,
-    config: Dict[str, Any],
+    config: dict[str, Any],
     overwrite: bool = False,
     ignore_permissions: bool = False,
 ) -> Dashboard:

@@ -16,7 +16,7 @@
 # under the License.
 import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)


 class UpdateDashboardCommand(UpdateMixin, BaseCommand):
-    def __init__(self, model_id: int, data: Dict[str, Any]):
+    def __init__(self, model_id: int, data: dict[str, Any]):
         self._model_id = model_id
         self._properties = data.copy()
         self._model: Optional[Dashboard] = None

@@ -64,9 +64,9 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
         return dashboard

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
-        owners_ids: Optional[List[int]] = self._properties.get("owners")
-        roles_ids: Optional[List[int]] = self._properties.get("roles")
+        exceptions: list[ValidationError] = []
+        owners_ids: Optional[list[int]] = self._properties.get("owners")
+        roles_ids: Optional[list[int]] = self._properties.get("roles")
         slug: Optional[str] = self._properties.get("slug")

         # Validate/populate model exists

@@ -17,7 +17,7 @@
 import json
 import logging
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from flask import g
 from flask_appbuilder.models.sqla.interface import SQLAInterface

@@ -68,12 +68,12 @@ class DashboardDAO(BaseDAO):
         return dashboard

     @staticmethod
-    def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]:
+    def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]:
         dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug)
         return dashboard.datasets_trimmed_for_slices()

     @staticmethod
-    def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]:
+    def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]:
         return DashboardDAO.get_by_id_or_slug(id_or_slug).slices

     @staticmethod

@@ -173,7 +173,7 @@ class DashboardDAO(BaseDAO):
         return model

     @staticmethod
-    def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None:
+    def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None:
         item_ids = [model.id for model in models] if models else []
         # bulk delete, first delete related data
         if models:

@@ -196,8 +196,8 @@ class DashboardDAO(BaseDAO):
     @staticmethod
     def set_dash_metadata(  # pylint: disable=too-many-locals
         dashboard: Dashboard,
-        data: Dict[Any, Any],
-        old_to_new_slice_ids: Optional[Dict[int, int]] = None,
+        data: dict[Any, Any],
+        old_to_new_slice_ids: Optional[dict[int, int]] = None,
         commit: bool = False,
     ) -> Dashboard:
         new_filter_scopes = {}

@@ -235,7 +235,7 @@ class DashboardDAO(BaseDAO):
         if "filter_scopes" in data:
             # replace filter_id and immune ids from old slice id to new slice id:
             # and remove slice ids that are not in dash anymore
-            slc_id_dict: Dict[int, int] = {}
+            slc_id_dict: dict[int, int] = {}
             if old_to_new_slice_ids:
                 slc_id_dict = {
                     old: new

@@ -288,7 +288,7 @@ class DashboardDAO(BaseDAO):
         return dashboard

     @staticmethod
-    def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
+    def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
         ids = [dash.id for dash in dashboards]
         return [
             star.obj_id

@@ -303,7 +303,7 @@ class DashboardDAO(BaseDAO):

     @classmethod
     def copy_dashboard(
-        cls, original_dash: Dashboard, data: Dict[str, Any]
+        cls, original_dash: Dashboard, data: dict[str, Any]
     ) -> Dashboard:
         dash = Dashboard()
         dash.owners = [g.user] if g.user else []

@@ -311,7 +311,7 @@ class DashboardDAO(BaseDAO):
         dash.css = data.get("css")

         metadata = json.loads(data["json_metadata"])
-        old_to_new_slice_ids: Dict[int, int] = {}
+        old_to_new_slice_ids: dict[int, int] = {}
         if data.get("duplicate_slices"):
             # Duplicating slices as well, mapping old ids to new ones
             for slc in original_dash.slices:

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict
+from typing import Any

 from flask_appbuilder.models.sqla import Model

@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)

 class CreateFilterSetCommand(BaseFilterSetCommand):
     # pylint: disable=C0103
-    def __init__(self, dashboard_id: int, data: Dict[str, Any]):
+    def __init__(self, dashboard_id: int, data: dict[str, Any]):
         super().__init__(dashboard_id)
         self._properties = data.copy()

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict
+from typing import Any

 from flask_appbuilder.models.sqla import Model

@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)


 class UpdateFilterSetCommand(BaseFilterSetCommand):
-    def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]):
+    def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]):
         super().__init__(dashboard_id)
         self._filter_set_id = filter_set_id
         self._properties = data.copy()

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict
+from typing import Any

 from flask_appbuilder.models.sqla import Model
 from sqlalchemy.exc import SQLAlchemyError

@@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO):
     model_cls = FilterSet

     @classmethod
-    def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
+    def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
         if cls.model_cls is None:
             raise DAOConfigError()
         model = FilterSet()

@@ -14,7 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, cast, Dict, Mapping
+from collections.abc import Mapping
+from typing import Any, cast

 from marshmallow import fields, post_load, Schema, ValidationError
 from marshmallow.validate import Length, OneOf

@@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema):
     @post_load
     def validate(
         self, data: Mapping[Any, Any], *, many: Any, partial: Any
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         self._validate_json_meta_data(data[JSON_METADATA_FIELD])
         if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data:
             raise ValidationError("owner_id is mandatory when owner_type is User")
-        return cast(Dict[str, Any], data)
+        return cast(dict[str, Any], data)


 class FilterSetPutSchema(FilterSetSchema):

@@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema):
     @post_load
     def validate(  # pylint: disable=unused-argument
         self, data: Mapping[Any, Any], *, many: Any, partial: Any
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if JSON_METADATA_FIELD in data:
             self._validate_json_meta_data(data[JSON_METADATA_FIELD])
-        return cast(Dict[str, Any], data)
+        return cast(dict[str, Any], data)


-def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None:
+def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None:
     if first_field in data and second_field not in data:
         raise ValidationError(
-            "{} must be included alongside {}".format(first_field, second_field)
+            f"{first_field} must be included alongside {second_field}"
         )

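The last change above is pyupgrade's other rewrite in this commit: `str.format` calls with positional placeholders become f-strings. The two forms produce identical strings, as a quick sketch shows (the values are illustrative):

    first_field, second_field = "owner_id", "owner_type"
    old = "{} must be included alongside {}".format(first_field, second_field)
    new = f"{first_field} must be included alongside {second_field}"
    assert old == new  # the rewrite is purely syntactic
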
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Type

 from flask import Response
 from flask_appbuilder.api import expose, protect, safe

@@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi):
     resource_name = "dashboard"
     openapi_spec_tag = "Dashboard Filter State"

-    def get_create_command(self) -> Type[CreateFilterStateCommand]:
+    def get_create_command(self) -> type[CreateFilterStateCommand]:
         return CreateFilterStateCommand

-    def get_update_command(self) -> Type[UpdateFilterStateCommand]:
+    def get_update_command(self) -> type[UpdateFilterStateCommand]:
         return UpdateFilterStateCommand

-    def get_get_command(self) -> Type[GetFilterStateCommand]:
+    def get_get_command(self) -> type[GetFilterStateCommand]:
         return GetFilterStateCommand

-    def get_delete_command(self) -> Type[DeleteFilterStateCommand]:
+    def get_delete_command(self) -> type[DeleteFilterStateCommand]:
         return DeleteFilterStateCommand

     @expose("/<int:pk>/filter_state", methods=("POST",))

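`typing.Type[X]` gets the same treatment as the containers: on 3.9+ the builtin `type` is subscriptable, so the factory methods above can drop the `typing` import entirely. A self-contained sketch of the pattern (the class names are placeholders, not Superset's API):

    class Command:
        def run(self) -> None:
            print("running")

    class CreateCommand(Command):
        pass

    def get_create_command() -> type[CreateCommand]:
        # Return the class object itself; callers instantiate it later,
        # as the REST API methods above do
        return CreateCommand

    get_create_command()().run()  # prints "running"
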
@@ -14,14 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Optional, TypedDict


 class DashboardPermalinkState(TypedDict):
-    dataMask: Optional[Dict[str, Any]]
-    activeTabs: Optional[List[str]]
+    dataMask: Optional[dict[str, Any]]
+    activeTabs: Optional[list[str]]
     anchor: Optional[str]
-    urlParams: Optional[List[Tuple[str, str]]]
+    urlParams: Optional[list[tuple[str, str]]]


 class DashboardPermalinkValue(TypedDict):

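TypedDict definitions pick up the builtin generics too; only `TypedDict`, `Optional`, and `Any` still need to come from `typing`. A runnable sketch mirroring the shape of `DashboardPermalinkState` (the field values are made up):

    from typing import Any, Optional, TypedDict

    class PermalinkState(TypedDict):
        dataMask: Optional[dict[str, Any]]
        activeTabs: Optional[list[str]]
        urlParams: Optional[list[tuple[str, str]]]

    state: PermalinkState = {
        "dataMask": None,
        "activeTabs": ["TAB-1"],
        "urlParams": [("standalone", "1")],
    }
    print(state["activeTabs"])
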
@@ -16,7 +16,7 @@
 # under the License.
 import json
 import re
-from typing import Any, Dict, Union
+from typing import Any, Union

 from marshmallow import fields, post_load, pre_load, Schema
 from marshmallow.validate import Length, ValidationError

@@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema):
     @pre_load
     def remove_show_native_filters(  # pylint: disable=unused-argument, no-self-use
         self,
-        data: Dict[str, Any],
+        data: dict[str, Any],
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Remove ``show_native_filters`` from the JSON metadata.

@@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema):
 class BaseDashboardSchema(Schema):
     # pylint: disable=no-self-use,unused-argument
     @post_load
-    def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+    def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
         if data.get("slug"):
             data["slug"] = data["slug"].strip()
             data["slug"] = data["slug"].replace(" ", "-")

@@ -19,7 +19,7 @@ import json
 import logging
 from datetime import datetime
 from io import BytesIO
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, cast, Optional
 from zipfile import is_zipfile, ZipFile

 from flask import request, Response, send_file

@@ -1328,13 +1328,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
             500:
              $ref: '#/components/responses/500'
         """
-        preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", [])
+        preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", [])
         available_databases = []
         for engine_spec, drivers in get_available_engine_specs().items():
             if not drivers:
                 continue

-            payload: Dict[str, Any] = {
+            payload: dict[str, Any] = {
                 "name": engine_spec.engine_name,
                 "engine": engine_spec.engine,
                 "available_drivers": sorted(drivers),

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask import current_app
 from flask_appbuilder.models.sqla import Model

@@ -47,7 +47,7 @@ stats_logger = current_app.config["STATS_LOGGER"]


 class CreateDatabaseCommand(BaseCommand):
-    def __init__(self, data: Dict[str, Any]):
+    def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()

     def run(self) -> Model:

@@ -128,7 +128,7 @@ class CreateDatabaseCommand(BaseCommand):
         return database

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
         database_name: Optional[str] = self._properties.get("database_name")
         if not sqlalchemy_uri:

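A side note on the `validate` bodies touched here and in the other command hunks: they accumulate `ValidationError`s in a plain `list[ValidationError]` and raise once at the end, so every problem surfaces in one pass. A minimal sketch of that pattern, assuming marshmallow is installed (the field names are illustrative, not Superset's API):

    from marshmallow import ValidationError

    properties: dict[str, str] = {}  # e.g. an empty payload
    exceptions: list[ValidationError] = []

    if not properties.get("sqlalchemy_uri"):
        exceptions.append(ValidationError("sqlalchemy_uri is required"))
    if not properties.get("database_name"):
        exceptions.append(ValidationError("database_name is required"))

    if exceptions:
        # Surface every collected message at once instead of failing fast
        raise ValidationError([str(exc) for exc in exceptions])
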
@@ -18,7 +18,8 @@

 import json
 import logging
-from typing import Any, Dict, Iterator, Tuple
+from collections.abc import Iterator
+from typing import Any

 import yaml

@@ -33,7 +34,7 @@ from superset.utils.ssh_tunnel import mask_password_info
 logger = logging.getLogger(__name__)


-def parse_extra(extra_payload: str) -> Dict[str, Any]:
+def parse_extra(extra_payload: str) -> dict[str, Any]:
     try:
         extra = json.loads(extra_payload)
     except json.decoder.JSONDecodeError:

@@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand):
     @staticmethod
     def _export(
         model: Database, export_related: bool = True
-    ) -> Iterator[Tuple[str, str]]:
+    ) -> Iterator[tuple[str, str]]:
         db_file_name = get_filename(model.database_name, model.id, skip_id=True)
         file_path = f"databases/{db_file_name}.yaml"

@@ -16,7 +16,7 @@
 # under the License.

 import logging
-from typing import Any, Dict
+from typing import Any

 from marshmallow.exceptions import ValidationError

@@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand):
     until it finds one that matches.
     """

-    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+    def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
         self.contents = contents
         self.args = args
         self.kwargs = kwargs

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-from typing import Any, Dict
+from typing import Any

 from marshmallow import Schema
 from sqlalchemy.orm import Session

@@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand):
     dao = DatabaseDAO
     model_name = "database"
     prefix = "databases/"
-    schemas: Dict[str, Schema] = {
+    schemas: dict[str, Schema] = {
         "databases/": ImportV1DatabaseSchema(),
         "datasets/": ImportV1DatasetSchema(),
     }

@@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand):

     @staticmethod
     def _import(
-        session: Session, configs: Dict[str, Any], overwrite: bool = False
+        session: Session, configs: dict[str, Any], overwrite: bool = False
     ) -> None:
         # first import databases
-        database_ids: Dict[str, int] = {}
+        database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/"):
                 database = import_database(session, config, overwrite=overwrite)

@@ -16,7 +16,7 @@
 # under the License.

 import json
-from typing import Any, Dict
+from typing import Any

 from sqlalchemy.orm import Session

@@ -28,7 +28,7 @@ from superset.models.core import Database

 def import_database(
     session: Session,
-    config: Dict[str, Any],
+    config: dict[str, Any],
     overwrite: bool = False,
     ignore_permissions: bool = False,
 ) -> Database:

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, cast, Dict
+from typing import Any, cast

 from superset.commands.base import BaseCommand
 from superset.connectors.sqla.models import SqlaTable

@@ -40,7 +40,7 @@ class TablesDatabaseCommand(BaseCommand):
         self._schema_name = schema_name
         self._force = force

-    def run(self) -> Dict[str, Any]:
+    def run(self) -> dict[str, Any]:
         self.validate()
         try:
             tables = security_manager.get_datasources_accessible_by_user(

@@ -17,7 +17,7 @@
 import logging
 import sqlite3
 from contextlib import closing
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from flask import current_app as app
 from flask_babel import gettext as _

@@ -64,7 +64,7 @@ def get_log_connection_action(


 class TestConnectionDatabaseCommand(BaseCommand):
-    def __init__(self, data: Dict[str, Any]):
+    def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()
         self._model: Optional[Database] = None

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)


 class UpdateDatabaseCommand(BaseCommand):
-    def __init__(self, model_id: int, data: Dict[str, Any]):
+    def __init__(self, model_id: int, data: dict[str, Any]):
         self._properties = data.copy()
         self._model_id = model_id
         self._model: Optional[Database] = None

@@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
             raise DatabaseConnectionFailedError() from ex

         # Update database schema permissions
-        new_schemas: List[str] = []
+        new_schemas: list[str] = []

         for schema in schemas:
             old_view_menu_name = security_manager.get_schema_perm(

@@ -164,7 +164,7 @@ class UpdateDatabaseCommand(BaseCommand):
                 chart.schema_perm = new_view_menu_name

     def validate(self) -> None:
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         # Validate/populate model exists
         self._model = DatabaseDAO.find_by_id(self._model_id)
         if not self._model:

@@ -16,7 +16,7 @@
 # under the License.
 import json
 from contextlib import closing
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from flask_babel import gettext as __

@@ -38,7 +38,7 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"}


 class ValidateDatabaseParametersCommand(BaseCommand):
-    def __init__(self, properties: Dict[str, Any]):
+    def __init__(self, properties: dict[str, Any]):
         self._properties = properties.copy()
         self._model: Optional[Database] = None

@@ -16,7 +16,7 @@
 # under the License.
 import logging
 import re
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional

 from flask import current_app
 from flask_babel import gettext as __

@@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)


 class ValidateSQLCommand(BaseCommand):
-    def __init__(self, model_id: int, data: Dict[str, Any]):
+    def __init__(self, model_id: int, data: dict[str, Any]):
         self._properties = data.copy()
         self._model_id = model_id
         self._model: Optional[Database] = None
-        self._validator: Optional[Type[BaseSQLValidator]] = None
+        self._validator: Optional[type[BaseSQLValidator]] = None

-    def run(self) -> List[Dict[str, Any]]:
+    def run(self) -> list[dict[str, Any]]:
         """
         Validates a SQL statement

@@ -97,9 +97,7 @@ class ValidateSQLCommand(BaseCommand):
         if not validators_by_engine or spec.engine not in validators_by_engine:
             raise NoValidatorConfigFoundError(
                 SupersetError(
-                    message=__(
-                        "no SQL validator is configured for {}".format(spec.engine)
-                    ),
+                    message=__(f"no SQL validator is configured for {spec.engine}"),
                     error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                     level=ErrorLevel.ERROR,
                 ),

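One behavioral note on the f-string rewrite above: both the old `.format()` call and the new f-string interpolate before `__` (gettext) ever sees the string, so the message handed to translation lookup is unchanged by this commit. A sketch with a stand-in for `__` (the engine value is illustrative):

    def __(msg: str) -> str:
        # Stand-in for flask_babel's gettext; real code would consult a catalog
        return msg

    engine = "postgresql"  # illustrative value for spec.engine
    old = __("no SQL validator is configured for {}".format(engine))
    new = __(f"no SQL validator is configured for {engine}")
    assert old == new
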
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from superset.dao.base import BaseDAO
 from superset.databases.filters import DatabaseFilter

@@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO):
     def update(
         cls,
         model: Database,
-        properties: Dict[str, Any],
+        properties: dict[str, Any],
         commit: bool = True,
     ) -> Database:
         """

@@ -93,7 +93,7 @@ class DatabaseDAO(BaseDAO):
         )

     @classmethod
-    def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
+    def get_related_objects(cls, database_id: int) -> dict[str, Any]:
         database: Any = cls.find_by_id(database_id)
         datasets = database.tables
         dataset_ids = [dataset.id for dataset in datasets]

@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Set
+from typing import Any

 from flask import g
 from flask_babel import lazy_gettext as _

@@ -30,7 +30,7 @@ from superset.views.base import BaseFilter

 def can_access_databases(
     view_menu_name: str,
-) -> Set[str]:
+) -> set[str]:
     return {
         security_manager.unpack_database_and_schema(vm).database
         for vm in security_manager.user_view_menu_names(view_menu_name)

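`typing.Set` disappears the same way; the set comprehension in `can_access_databases` now pairs with a builtin `set[str]` return annotation. A toy version of the shape (the view-menu name format here is a guess for illustration only, not Superset's actual encoding):

    def database_names(view_menus: list[str]) -> set[str]:
        # Extract the database part of "[db].[schema]"-style names
        return {vm.split(".")[0].strip("[]") for vm in view_menus}

    print(database_names(["[examples].[public]", "[examples].[other]"]))
    # {'examples'}
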
@@ -19,7 +19,7 @@

 import inspect
 import json
-from typing import Any, Dict, List
+from typing import Any

 from flask import current_app
 from flask_babel import lazy_gettext as _

@@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin:  # pylint: disable=too-few-public-methods

     @pre_load
     def build_sqlalchemy_uri(
-        self, data: Dict[str, Any], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, data: dict[str, Any], **kwargs: Any
+    ) -> dict[str, Any]:
         """
         Build SQLAlchemy URI from separate parameters.

@@ -325,9 +325,9 @@ class DatabaseParametersSchemaMixin:  # pylint: disable=too-few-public-methods

 def rename_encrypted_extra(
     self: Schema,
-    data: Dict[str, Any],
+    data: dict[str, Any],
     **kwargs: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Rename ``encrypted_extra`` to ``masked_encrypted_extra``.

@@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema):
 class ImportV1DatabaseExtraSchema(Schema):
     @pre_load
     def fix_schemas_allowed_for_csv_upload(  # pylint: disable=invalid-name
-        self, data: Dict[str, Any], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, data: dict[str, Any], **kwargs: Any
+    ) -> dict[str, Any]:
         """
         Fixes for ``schemas_allowed_for_csv_upload``.
         """

@@ -744,8 +744,8 @@ class ImportV1DatabaseExtraSchema(Schema):
 class ImportV1DatabaseSchema(Schema):
     @pre_load
     def fix_allow_csv_upload(
-        self, data: Dict[str, Any], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, data: dict[str, Any], **kwargs: Any
+    ) -> dict[str, Any]:
         """
         Fix for ``allow_csv_upload`` .
         """

@@ -775,7 +775,7 @@ class ImportV1DatabaseSchema(Schema):
     ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)

     @validates_schema
-    def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
+    def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None:
         """If sqlalchemy_uri has a masked password, password is required"""
         uuid = data["uuid"]
         existing = db.session.query(Database).filter_by(uuid=uuid).first()

@@ -789,7 +789,7 @@ class ImportV1DatabaseSchema(Schema):

     @validates_schema
     def validate_ssh_tunnel_credentials(
-        self, data: Dict[str, Any], **kwargs: Any
+        self, data: dict[str, Any], **kwargs: Any
     ) -> None:
         """If ssh_tunnel has a masked credentials, credentials are required"""
         uuid = data["uuid"]

@@ -829,7 +829,7 @@ class ImportV1DatabaseSchema(Schema):
         # or there're times where it's masked.
         # If both are masked, we need to return a list of errors
         # so the UI ask for both fields at the same time if needed
-        exception_messages: List[str] = []
+        exception_messages: list[str] = []
         if private_key is None or private_key == PASSWORD_MASK:
             # If we get here we need to ask for the private key
             exception_messages.append(

@@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict):
     pass


-def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]:  # type: ignore
+def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]:  # type: ignore
     ret = {}
     if isinstance(field, EncryptedField):
         if self.openapi_version.major > 2:

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError

@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)


 class CreateSSHTunnelCommand(BaseCommand):
-    def __init__(self, database_id: int, data: Dict[str, Any]):
+    def __init__(self, database_id: int, data: dict[str, Any]):
         self._properties = data.copy()
         self._properties["database_id"] = database_id

@@ -61,7 +61,7 @@ class CreateSSHTunnelCommand(BaseCommand):
     def validate(self) -> None:
         # TODO(hughhh): check to make sure the server port is not localhost
         # using the config.SSH_TUNNEL_MANAGER
-        exceptions: List[ValidationError] = []
+        exceptions: list[ValidationError] = []
         database_id: Optional[int] = self._properties.get("database_id")
         server_address: Optional[str] = self._properties.get("server_address")
         server_port: Optional[int] = self._properties.get("server_port")

|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
||||
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UpdateSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO):
|
|||
def update(
|
||||
cls,
|
||||
model: SSHTunnel,
|
||||
properties: Dict[str, Any],
|
||||
properties: dict[str, Any],
|
||||
commit: bool = True,
|
||||
) -> SSHTunnel:
|
||||
"""
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import current_app
|
||||
|
@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
|
|||
]
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[str, Any]:
|
||||
def data(self) -> dict[str, Any]:
|
||||
output = {
|
||||
"id": self.id,
|
||||
"server_address": self.server_address,
|
||||
|
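To reproduce these rewrites locally, the new hooks can be run directly, e.g. `pre-commit run pyupgrade --all-files` followed by `pre-commit run pycln --all-files`. The property change just above is typical of the result; a final runnable sketch of a `dict[str, Any]`-annotated property (a simplified stand-in, not the real SSHTunnel model):

    from typing import Any

    class Tunnel:
        def __init__(self, server_address: str, server_port: int) -> None:
            self.server_address = server_address
            self.server_port = server_port

        @property
        def data(self) -> dict[str, Any]:
            # Builtin generic in a property return annotation
            return {
                "server_address": self.server_address,
                "server_port": self.server_port,
            }

    print(Tunnel("bastion.example.com", 22).data)
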
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue