mirror of https://github.com/apache/superset.git
chore(pre-commit): Add pyupgrade and pycln hooks (#24197)
This commit is contained in:
parent
7d7ce63970
commit
a4d5d7c6b9
|
@ -15,14 +15,28 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
repos:
|
repos:
|
||||||
|
- repo: https://github.com/MarcoGorelli/auto-walrus
|
||||||
|
rev: v0.2.2
|
||||||
|
hooks:
|
||||||
|
- id: auto-walrus
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v3.4.0
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
args:
|
||||||
|
- --py39-plus
|
||||||
|
- repo: https://github.com/hadialqattan/pycln
|
||||||
|
rev: v2.1.2
|
||||||
|
hooks:
|
||||||
|
- id: pycln
|
||||||
|
args:
|
||||||
|
- --disable-all-dunder-policy
|
||||||
|
- --exclude=superset/config.py
|
||||||
|
- --extend-exclude=tests/integration_tests/superset_test_config.*.py
|
||||||
- repo: https://github.com/PyCQA/isort
|
- repo: https://github.com/PyCQA/isort
|
||||||
rev: 5.12.0
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/MarcoGorelli/auto-walrus
|
|
||||||
rev: v0.2.2
|
|
||||||
hooks:
|
|
||||||
- id: auto-walrus
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.3.0
|
rev: v1.3.0
|
||||||
hooks:
|
hooks:
|
||||||
|
|
|
@ -17,8 +17,9 @@ import csv as lib_csv
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from collections.abc import Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from click.core import Context
|
from click.core import Context
|
||||||
|
@ -67,15 +68,15 @@ class GitChangeLog:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
version: str,
|
version: str,
|
||||||
logs: List[GitLog],
|
logs: list[GitLog],
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
risk: Optional[bool] = False,
|
risk: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._version = version
|
self._version = version
|
||||||
self._logs = logs
|
self._logs = logs
|
||||||
self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {}
|
self._pr_logs_with_details: dict[int, dict[str, Any]] = {}
|
||||||
self._github_login_cache: Dict[str, Optional[str]] = {}
|
self._github_login_cache: dict[str, Optional[str]] = {}
|
||||||
self._github_prs: Dict[int, Any] = {}
|
self._github_prs: dict[int, Any] = {}
|
||||||
self._wait = 10
|
self._wait = 10
|
||||||
github_token = access_token or os.environ.get("GITHUB_TOKEN")
|
github_token = access_token or os.environ.get("GITHUB_TOKEN")
|
||||||
self._github = Github(github_token)
|
self._github = Github(github_token)
|
||||||
|
@ -126,7 +127,7 @@ class GitChangeLog:
|
||||||
"superset/migrations/versions/" in file.filename for file in commit.files
|
"superset/migrations/versions/" in file.filename for file in commit.files
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:
|
def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]:
|
||||||
pr_number = git_log.pr_number
|
pr_number = git_log.pr_number
|
||||||
if pr_number:
|
if pr_number:
|
||||||
detail = self._pr_logs_with_details.get(pr_number)
|
detail = self._pr_logs_with_details.get(pr_number)
|
||||||
|
@ -156,7 +157,7 @@ class GitChangeLog:
|
||||||
|
|
||||||
return detail
|
return detail
|
||||||
|
|
||||||
def _is_risk_pull_request(self, labels: List[Any]) -> bool:
|
def _is_risk_pull_request(self, labels: list[Any]) -> bool:
|
||||||
for label in labels:
|
for label in labels:
|
||||||
risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
|
risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
|
||||||
if risk_label is not None:
|
if risk_label is not None:
|
||||||
|
@ -174,8 +175,8 @@ class GitChangeLog:
|
||||||
|
|
||||||
def _parse_change_log(
|
def _parse_change_log(
|
||||||
self,
|
self,
|
||||||
changelog: Dict[str, str],
|
changelog: dict[str, str],
|
||||||
pr_info: Dict[str, str],
|
pr_info: dict[str, str],
|
||||||
github_login: str,
|
github_login: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
formatted_pr = (
|
formatted_pr = (
|
||||||
|
@ -227,7 +228,7 @@ class GitChangeLog:
|
||||||
result += f"**{key}** {changelog[key]}\n"
|
result += f"**{key}** {changelog[key]}\n"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Dict[str, Any]]:
|
def __iter__(self) -> Iterator[dict[str, Any]]:
|
||||||
for log in self._logs:
|
for log in self._logs:
|
||||||
yield {
|
yield {
|
||||||
"pr_number": log.pr_number,
|
"pr_number": log.pr_number,
|
||||||
|
@ -250,20 +251,20 @@ class GitLogs:
|
||||||
|
|
||||||
def __init__(self, git_ref: str) -> None:
|
def __init__(self, git_ref: str) -> None:
|
||||||
self._git_ref = git_ref
|
self._git_ref = git_ref
|
||||||
self._logs: List[GitLog] = []
|
self._logs: list[GitLog] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def git_ref(self) -> str:
|
def git_ref(self) -> str:
|
||||||
return self._git_ref
|
return self._git_ref
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logs(self) -> List[GitLog]:
|
def logs(self) -> list[GitLog]:
|
||||||
return self._logs
|
return self._logs
|
||||||
|
|
||||||
def fetch(self) -> None:
|
def fetch(self) -> None:
|
||||||
self._logs = list(map(self._parse_log, self._git_logs()))[::-1]
|
self._logs = list(map(self._parse_log, self._git_logs()))[::-1]
|
||||||
|
|
||||||
def diff(self, git_logs: "GitLogs") -> List[GitLog]:
|
def diff(self, git_logs: "GitLogs") -> list[GitLog]:
|
||||||
return [log for log in git_logs.logs if log not in self._logs]
|
return [log for log in git_logs.logs if log not in self._logs]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
@ -284,7 +285,7 @@ class GitLogs:
|
||||||
print(f"Could not checkout {git_ref}")
|
print(f"Could not checkout {git_ref}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def _git_logs(self) -> List[str]:
|
def _git_logs(self) -> list[str]:
|
||||||
# let's get current git ref so we can revert it back
|
# let's get current git ref so we can revert it back
|
||||||
current_git_ref = self._git_get_current_head()
|
current_git_ref = self._git_get_current_head()
|
||||||
self._git_checkout(self._git_ref)
|
self._git_checkout(self._git_ref)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from click.core import Context
|
from click.core import Context
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ PROJECT_MODULE = "superset"
|
||||||
PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"
|
PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"
|
||||||
|
|
||||||
|
|
||||||
def string_comma_to_list(message: str) -> List[str]:
|
def string_comma_to_list(message: str) -> list[str]:
|
||||||
if not message:
|
if not message:
|
||||||
return []
|
return []
|
||||||
return [element.strip() for element in message.split(",")]
|
return [element.strip() for element in message.split(",")]
|
||||||
|
@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str:
|
||||||
return template.render(kwargs)
|
return template.render(kwargs)
|
||||||
|
|
||||||
|
|
||||||
class BaseParameters(object):
|
class BaseParameters:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
version: str,
|
version: str,
|
||||||
|
@ -60,7 +60,7 @@ class BaseParameters(object):
|
||||||
) -> None:
|
) -> None:
|
||||||
self.version = version
|
self.version = version
|
||||||
self.version_rc = version_rc
|
self.version_rc = version_rc
|
||||||
self.template_arguments: Dict[str, Any] = {}
|
self.template_arguments: dict[str, Any] = {}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Apache Credentials: {self.version}/{self.version_rc}"
|
return f"Apache Credentials: {self.version}/{self.version_rc}"
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from cachelib.file import FileSystemCache
|
from cachelib.file import FileSystemCache
|
||||||
|
@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
|
||||||
error_msg = "The environment variable {} was missing, abort...".format(
|
error_msg = "The environment variable {} was missing, abort...".format(
|
||||||
var_name
|
var_name
|
||||||
)
|
)
|
||||||
raise EnvironmentError(error_msg)
|
raise OSError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")
|
DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")
|
||||||
|
@ -53,7 +52,7 @@ DATABASE_PORT = get_env_variable("DATABASE_PORT")
|
||||||
DATABASE_DB = get_env_variable("DATABASE_DB")
|
DATABASE_DB = get_env_variable("DATABASE_DB")
|
||||||
|
|
||||||
# The SQLAlchemy connection string.
|
# The SQLAlchemy connection string.
|
||||||
SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
|
SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format(
|
||||||
DATABASE_DIALECT,
|
DATABASE_DIALECT,
|
||||||
DATABASE_USER,
|
DATABASE_USER,
|
||||||
DATABASE_PASSWORD,
|
DATABASE_PASSWORD,
|
||||||
|
@ -80,7 +79,7 @@ CACHE_CONFIG = {
|
||||||
DATA_CACHE_CONFIG = CACHE_CONFIG
|
DATA_CACHE_CONFIG = CACHE_CONFIG
|
||||||
|
|
||||||
|
|
||||||
class CeleryConfig(object):
|
class CeleryConfig:
|
||||||
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
|
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
|
||||||
imports = ("superset.sql_lab",)
|
imports = ("superset.sql_lab",)
|
||||||
result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
|
result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
|
||||||
|
|
|
@ -23,7 +23,7 @@ from graphlib import TopologicalSorter
|
||||||
from inspect import getsource
|
from inspect import getsource
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, Dict, List, Set, Type
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(module) # type: ignore
|
spec.loader.exec_module(module) # type: ignore
|
||||||
return module
|
return module
|
||||||
raise Exception(
|
raise Exception(f"No module spec found in location: `{str(filepath)}`")
|
||||||
"No module spec found in location: `{path}`".format(path=str(filepath))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_modified_tables(module: ModuleType) -> Set[str]:
|
def extract_modified_tables(module: ModuleType) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Extract the tables being modified by a migration script.
|
Extract the tables being modified by a migration script.
|
||||||
|
|
||||||
|
@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
|
||||||
actually traversing the AST.
|
actually traversing the AST.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tables: Set[str] = set()
|
tables: set[str] = set()
|
||||||
for function in {"upgrade", "downgrade"}:
|
for function in {"upgrade", "downgrade"}:
|
||||||
source = getsource(getattr(module, function))
|
source = getsource(getattr(module, function))
|
||||||
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
|
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
|
||||||
|
@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
|
|
||||||
def find_models(module: ModuleType) -> List[Type[Model]]:
|
def find_models(module: ModuleType) -> list[type[Model]]:
|
||||||
"""
|
"""
|
||||||
Find all models in a migration script.
|
Find all models in a migration script.
|
||||||
"""
|
"""
|
||||||
models: List[Type[Model]] = []
|
models: list[type[Model]] = []
|
||||||
tables = extract_modified_tables(module)
|
tables = extract_modified_tables(module)
|
||||||
|
|
||||||
# add models defined explicitly in the migration script
|
# add models defined explicitly in the migration script
|
||||||
|
@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
|
||||||
sorter: TopologicalSorter[Any] = TopologicalSorter()
|
sorter: TopologicalSorter[Any] = TopologicalSorter()
|
||||||
for model in models:
|
for model in models:
|
||||||
inspector = inspect(model)
|
inspector = inspect(model)
|
||||||
dependent_tables: List[str] = []
|
dependent_tables: list[str] = []
|
||||||
for column in inspector.columns.values():
|
for column in inspector.columns.values():
|
||||||
for foreign_key in column.foreign_keys:
|
for foreign_key in column.foreign_keys:
|
||||||
if foreign_key.column.table.name != model.__tablename__:
|
if foreign_key.column.table.name != model.__tablename__:
|
||||||
|
@ -174,7 +172,7 @@ def main(
|
||||||
|
|
||||||
print("\nIdentifying models used in the migration:")
|
print("\nIdentifying models used in the migration:")
|
||||||
models = find_models(module)
|
models = find_models(module)
|
||||||
model_rows: Dict[Type[Model], int] = {}
|
model_rows: dict[type[Model], int] = {}
|
||||||
for model in models:
|
for model in models:
|
||||||
rows = session.query(model).count()
|
rows = session.query(model).count()
|
||||||
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
|
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
|
||||||
|
@ -182,7 +180,7 @@ def main(
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
print("Benchmarking migration")
|
print("Benchmarking migration")
|
||||||
results: Dict[str, float] = {}
|
results: dict[str, float] = {}
|
||||||
start = time.time()
|
start = time.time()
|
||||||
upgrade(revision=revision)
|
upgrade(revision=revision)
|
||||||
duration = time.time() - start
|
duration = time.time() - start
|
||||||
|
@ -190,14 +188,14 @@ def main(
|
||||||
print(f"Migration on current DB took: {duration:.2f} seconds")
|
print(f"Migration on current DB took: {duration:.2f} seconds")
|
||||||
|
|
||||||
min_entities = 10
|
min_entities = 10
|
||||||
new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
|
new_models: dict[type[Model], list[Model]] = defaultdict(list)
|
||||||
while min_entities <= limit:
|
while min_entities <= limit:
|
||||||
downgrade(revision=down_revision)
|
downgrade(revision=down_revision)
|
||||||
print(f"Running with at least {min_entities} entities of each model")
|
print(f"Running with at least {min_entities} entities of each model")
|
||||||
for model in models:
|
for model in models:
|
||||||
missing = min_entities - model_rows[model]
|
missing = min_entities - model_rows[model]
|
||||||
if missing > 0:
|
if missing > 0:
|
||||||
entities: List[Model] = []
|
entities: list[Model] = []
|
||||||
print(f"- Adding {missing} entities to the {model.__name__} model")
|
print(f"- Adding {missing} entities to the {model.__name__} model")
|
||||||
bar = ChargingBar("Processing", max=missing)
|
bar = ChargingBar("Processing", max=missing)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -33,13 +33,13 @@ Example:
|
||||||
./cancel_github_workflows.py 1024 --include-last
|
./cancel_github_workflows.py 1024 --include-last
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
|
from collections.abc import Iterable, Iterator
|
||||||
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
from click.exceptions import ClickException
|
from click.exceptions import ClickException
|
||||||
from dateutil import parser
|
from dateutil import parser
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
github_token = os.environ.get("GITHUB_TOKEN")
|
github_token = os.environ.get("GITHUB_TOKEN")
|
||||||
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
|
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
|
||||||
|
@ -47,7 +47,7 @@ github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
|
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
resp = requests.request(
|
resp = requests.request(
|
||||||
method,
|
method,
|
||||||
f"https://api.github.com/{endpoint.lstrip('/')}",
|
f"https://api.github.com/{endpoint.lstrip('/')}",
|
||||||
|
@ -61,8 +61,8 @@ def request(
|
||||||
|
|
||||||
def list_runs(
|
def list_runs(
|
||||||
repo: str,
|
repo: str,
|
||||||
params: Optional[Dict[str, str]] = None,
|
params: Optional[dict[str, str]] = None,
|
||||||
) -> Iterator[Dict[str, Any]]:
|
) -> Iterator[dict[str, Any]]:
|
||||||
"""List all github workflow runs.
|
"""List all github workflow runs.
|
||||||
Returns:
|
Returns:
|
||||||
An iterator that will iterate through all pages of matching runs."""
|
An iterator that will iterate through all pages of matching runs."""
|
||||||
|
@ -77,16 +77,15 @@ def list_runs(
|
||||||
params={**params, "per_page": 100, "page": page},
|
params={**params, "per_page": 100, "page": page},
|
||||||
)
|
)
|
||||||
total_count = result["total_count"]
|
total_count = result["total_count"]
|
||||||
for item in result["workflow_runs"]:
|
yield from result["workflow_runs"]
|
||||||
yield item
|
|
||||||
page += 1
|
page += 1
|
||||||
|
|
||||||
|
|
||||||
def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
|
def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]:
|
||||||
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
|
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
|
||||||
|
|
||||||
|
|
||||||
def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
|
def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]:
|
||||||
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
|
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ def get_runs(
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
statuses: Iterable[str] = ("queued", "in_progress"),
|
statuses: Iterable[str] = ("queued", "in_progress"),
|
||||||
events: Iterable[str] = ("pull_request", "push"),
|
events: Iterable[str] = ("pull_request", "push"),
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Get workflow runs associated with the given branch"""
|
"""Get workflow runs associated with the given branch"""
|
||||||
return [
|
return [
|
||||||
item
|
item
|
||||||
|
@ -108,7 +107,7 @@ def get_runs(
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def print_commit(commit: Dict[str, Any], branch: str) -> None:
|
def print_commit(commit: dict[str, Any], branch: str) -> None:
|
||||||
"""Print out commit message for verification"""
|
"""Print out commit message for verification"""
|
||||||
indented_message = " \n".join(commit["message"].split("\n"))
|
indented_message = " \n".join(commit["message"].split("\n"))
|
||||||
date_str = (
|
date_str = (
|
||||||
|
@ -155,7 +154,7 @@ Date: {date_str}
|
||||||
def cancel_github_workflows(
|
def cancel_github_workflows(
|
||||||
branch_or_pull: Optional[str],
|
branch_or_pull: Optional[str],
|
||||||
repo: str,
|
repo: str,
|
||||||
event: List[str],
|
event: list[str],
|
||||||
include_last: bool,
|
include_last: bool,
|
||||||
include_running: bool,
|
include_running: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -24,7 +24,7 @@ def cleanup_permissions() -> None:
|
||||||
pvms = security_manager.get_session.query(
|
pvms = security_manager.get_session.query(
|
||||||
security_manager.permissionview_model
|
security_manager.permissionview_model
|
||||||
).all()
|
).all()
|
||||||
print("# of permission view menus is: {}".format(len(pvms)))
|
print(f"# of permission view menus is: {len(pvms)}")
|
||||||
pvms_dict = defaultdict(list)
|
pvms_dict = defaultdict(list)
|
||||||
for pvm in pvms:
|
for pvm in pvms:
|
||||||
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
|
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
|
||||||
|
@ -43,7 +43,7 @@ def cleanup_permissions() -> None:
|
||||||
pvms = security_manager.get_session.query(
|
pvms = security_manager.get_session.query(
|
||||||
security_manager.permissionview_model
|
security_manager.permissionview_model
|
||||||
).all()
|
).all()
|
||||||
print("Stage 1: # of permission view menus is: {}".format(len(pvms)))
|
print(f"Stage 1: # of permission view menus is: {len(pvms)}")
|
||||||
|
|
||||||
# 2. Clean up None permissions or view menus
|
# 2. Clean up None permissions or view menus
|
||||||
pvms = security_manager.get_session.query(
|
pvms = security_manager.get_session.query(
|
||||||
|
@ -57,7 +57,7 @@ def cleanup_permissions() -> None:
|
||||||
pvms = security_manager.get_session.query(
|
pvms = security_manager.get_session.query(
|
||||||
security_manager.permissionview_model
|
security_manager.permissionview_model
|
||||||
).all()
|
).all()
|
||||||
print("Stage 2: # of permission view menus is: {}".format(len(pvms)))
|
print(f"Stage 2: # of permission view menus is: {len(pvms)}")
|
||||||
|
|
||||||
# 3. Delete empty permission view menus from roles
|
# 3. Delete empty permission view menus from roles
|
||||||
roles = security_manager.get_session.query(security_manager.role_model).all()
|
roles = security_manager.get_session.query(security_manager.role_model).all()
|
||||||
|
|
6
setup.py
6
setup.py
|
@ -14,21 +14,19 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
|
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
|
||||||
PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json")
|
PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json")
|
||||||
|
|
||||||
with open(PACKAGE_JSON, "r") as package_file:
|
with open(PACKAGE_JSON) as package_file:
|
||||||
version_string = json.load(package_file)["version"]
|
version_string = json.load(package_file)["version"]
|
||||||
|
|
||||||
with io.open("README.md", "r", encoding="utf-8") as f:
|
with open("README.md", encoding="utf-8") as f:
|
||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import ipaddress
|
import ipaddress
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import Column
|
from sqlalchemy import Column
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:
|
||||||
|
|
||||||
# Make this return a single clause
|
# Make this return a single clause
|
||||||
def cidr_translate_filter_func(
|
def cidr_translate_filter_func(
|
||||||
col: Column, operator: FilterOperator, values: List[Any]
|
col: Column, operator: FilterOperator, values: list[Any]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Convert a passed in column, FilterOperator and
|
Convert a passed in column, FilterOperator and
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import Column
|
from sqlalchemy import Column
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ from superset.advanced_data_type.types import (
|
||||||
)
|
)
|
||||||
from superset.utils.core import FilterOperator, FilterStringOperators
|
from superset.utils.core import FilterOperator, FilterStringOperators
|
||||||
|
|
||||||
port_conversion_dict: Dict[str, List[int]] = {
|
port_conversion_dict: dict[str, list[int]] = {
|
||||||
"http": [80],
|
"http": [80],
|
||||||
"ssh": [22],
|
"ssh": [22],
|
||||||
"https": [443],
|
"https": [443],
|
||||||
|
@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeRespo
|
||||||
|
|
||||||
|
|
||||||
def port_translate_filter_func(
|
def port_translate_filter_func(
|
||||||
col: Column, operator: FilterOperator, values: List[Any]
|
col: Column, operator: FilterOperator, values: list[Any]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Convert a passed in column, FilterOperator
|
Convert a passed in column, FilterOperator
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, List, Optional, TypedDict, Union
|
from typing import Any, Callable, Optional, TypedDict, Union
|
||||||
|
|
||||||
from sqlalchemy import Column
|
from sqlalchemy import Column
|
||||||
from sqlalchemy.sql.expression import BinaryExpression
|
from sqlalchemy.sql.expression import BinaryExpression
|
||||||
|
@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
advanced_data_type: str
|
advanced_data_type: str
|
||||||
values: List[
|
values: list[
|
||||||
Union[FilterValues, None]
|
Union[FilterValues, None]
|
||||||
] # unparsed value (usually text when passed from text box)
|
] # unparsed value (usually text when passed from text box)
|
||||||
|
|
||||||
|
@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
error_message: Optional[str]
|
error_message: Optional[str]
|
||||||
values: List[Any] # parsed value (can be any value)
|
values: list[Any] # parsed value (can be any value)
|
||||||
display_value: str # The string representation of the parsed values
|
display_value: str # The string representation of the parsed values
|
||||||
valid_filter_operators: List[FilterStringOperators]
|
valid_filter_operators: list[FilterStringOperators]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -54,6 +54,6 @@ class AdvancedDataType:
|
||||||
|
|
||||||
verbose_name: str
|
verbose_name: str
|
||||||
description: str
|
description: str
|
||||||
valid_data_types: List[str]
|
valid_data_types: list[str]
|
||||||
translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse]
|
translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse]
|
||||||
translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression]
|
translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression]
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from flask import request, Response
|
from flask import request, Response
|
||||||
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
|
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
|
||||||
|
@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _apply_layered_relation_to_rison( # pylint: disable=invalid-name
|
def _apply_layered_relation_to_rison( # pylint: disable=invalid-name
|
||||||
layer_id: int, rison_parameters: Dict[str, Any]
|
layer_id: int, rison_parameters: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
if "filters" not in rison_parameters:
|
if "filters" not in rison_parameters:
|
||||||
rison_parameters["filters"] = []
|
rison_parameters["filters"] = []
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from superset.annotation_layers.annotations.commands.exceptions import (
|
from superset.annotation_layers.annotations.commands.exceptions import (
|
||||||
AnnotationBulkDeleteFailedError,
|
AnnotationBulkDeleteFailedError,
|
||||||
|
@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BulkDeleteAnnotationCommand(BaseCommand):
|
class BulkDeleteAnnotationCommand(BaseCommand):
|
||||||
def __init__(self, model_ids: List[int]):
|
def __init__(self, model_ids: list[int]):
|
||||||
self._model_ids = model_ids
|
self._model_ids = model_ids
|
||||||
self._models: Optional[List[Annotation]] = None
|
self._models: Optional[list[Annotation]] = None
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateAnnotationCommand(BaseCommand):
|
class CreateAnnotationCommand(BaseCommand):
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
|
@ -50,7 +50,7 @@ class CreateAnnotationCommand(BaseCommand):
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
layer_id: Optional[int] = self._properties.get("layer")
|
layer_id: Optional[int] = self._properties.get("layer")
|
||||||
start_dttm: Optional[datetime] = self._properties.get("start_dttm")
|
start_dttm: Optional[datetime] = self._properties.get("start_dttm")
|
||||||
end_dttm: Optional[datetime] = self._properties.get("end_dttm")
|
end_dttm: Optional[datetime] = self._properties.get("end_dttm")
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateAnnotationCommand(BaseCommand):
|
class UpdateAnnotationCommand(BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[Annotation] = None
|
self._model: Optional[Annotation] = None
|
||||||
|
@ -54,7 +54,7 @@ class UpdateAnnotationCommand(BaseCommand):
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
layer_id: Optional[int] = self._properties.get("layer")
|
layer_id: Optional[int] = self._properties.get("layer")
|
||||||
short_descr: str = self._properties.get("short_descr", "")
|
short_descr: str = self._properties.get("short_descr", "")
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO):
|
||||||
model_cls = Annotation
|
model_cls = Annotation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None:
|
def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
|
||||||
item_ids = [model.id for model in models] if models else []
|
item_ids = [model.id for model in models] if models else []
|
||||||
try:
|
try:
|
||||||
db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete(
|
db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete(
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from superset.annotation_layers.commands.exceptions import (
|
from superset.annotation_layers.commands.exceptions import (
|
||||||
AnnotationLayerBulkDeleteFailedError,
|
AnnotationLayerBulkDeleteFailedError,
|
||||||
|
@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BulkDeleteAnnotationLayerCommand(BaseCommand):
|
class BulkDeleteAnnotationLayerCommand(BaseCommand):
|
||||||
def __init__(self, model_ids: List[int]):
|
def __init__(self, model_ids: list[int]):
|
||||||
self._model_ids = model_ids
|
self._model_ids = model_ids
|
||||||
self._models: Optional[List[AnnotationLayer]] = None
|
self._models: Optional[list[AnnotationLayer]] = None
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateAnnotationLayerCommand(BaseCommand):
|
class CreateAnnotationLayerCommand(BaseCommand):
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
|
@ -46,7 +46,7 @@ class CreateAnnotationLayerCommand(BaseCommand):
|
||||||
return annotation_layer
|
return annotation_layer
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
|
|
||||||
name = self._properties.get("name", "")
|
name = self._properties.get("name", "")
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateAnnotationLayerCommand(BaseCommand):
|
class UpdateAnnotationLayerCommand(BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[AnnotationLayer] = None
|
self._model: Optional[AnnotationLayer] = None
|
||||||
|
@ -50,7 +50,7 @@ class UpdateAnnotationLayerCommand(BaseCommand):
|
||||||
return annotation_layer
|
return annotation_layer
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
name = self._properties.get("name", "")
|
name = self._properties.get("name", "")
|
||||||
self._model = AnnotationLayerDAO.find_by_id(self._model_id)
|
self._model = AnnotationLayerDAO.find_by_id(self._model_id)
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bulk_delete(
|
def bulk_delete(
|
||||||
models: Optional[List[AnnotationLayer]], commit: bool = True
|
models: Optional[list[AnnotationLayer]], commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
item_ids = [model.id for model in models] if models else []
|
item_ids = [model.id for model in models] if models else []
|
||||||
try:
|
try:
|
||||||
|
@ -46,7 +46,7 @@ class AnnotationLayerDAO(BaseDAO):
|
||||||
raise DAODeleteFailedError() from ex
|
raise DAODeleteFailedError() from ex
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_annotations(model_id: Union[int, List[int]]) -> bool:
|
def has_annotations(model_id: Union[int, list[int]]) -> bool:
|
||||||
if isinstance(model_id, list):
|
if isinstance(model_id, list):
|
||||||
return (
|
return (
|
||||||
db.session.query(AnnotationLayer)
|
db.session.query(AnnotationLayer)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
|
||||||
|
@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BulkDeleteChartCommand(BaseCommand):
|
class BulkDeleteChartCommand(BaseCommand):
|
||||||
def __init__(self, model_ids: List[int]):
|
def __init__(self, model_ids: list[int]):
|
||||||
self._model_ids = model_ids
|
self._model_ids = model_ids
|
||||||
self._models: Optional[List[Slice]] = None
|
self._models: Optional[list[Slice]] = None
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateChartCommand(CreateMixin, BaseCommand):
|
class CreateChartCommand(CreateMixin, BaseCommand):
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
|
@ -56,7 +56,7 @@ class CreateChartCommand(CreateMixin, BaseCommand):
|
||||||
datasource_type = self._properties["datasource_type"]
|
datasource_type = self._properties["datasource_type"]
|
||||||
datasource_id = self._properties["datasource_id"]
|
datasource_id = self._properties["datasource_id"]
|
||||||
dashboard_ids = self._properties.get("dashboards", [])
|
dashboard_ids = self._properties.get("dashboards", [])
|
||||||
owner_ids: Optional[List[int]] = self._properties.get("owners")
|
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||||
|
|
||||||
# Validate/Populate datasource
|
# Validate/Populate datasource
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterator, Tuple
|
from collections.abc import Iterator
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand):
|
||||||
not_found = ChartNotFoundError
|
not_found = ChartNotFoundError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]:
|
def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]:
|
||||||
file_name = get_filename(model.slice_name, model.id)
|
file_name = get_filename(model.slice_name, model.id)
|
||||||
file_path = f"charts/{file_name}.yaml"
|
file_path = f"charts/{file_name}.yaml"
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from marshmallow.exceptions import ValidationError
|
from marshmallow.exceptions import ValidationError
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand):
|
||||||
until it finds one that matches.
|
until it finds one that matches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
|
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import json
|
from typing import Any
|
||||||
from typing import Any, Dict, Set
|
|
||||||
|
|
||||||
from marshmallow import Schema
|
from marshmallow import Schema
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand):
|
||||||
dao = ChartDAO
|
dao = ChartDAO
|
||||||
model_name = "chart"
|
model_name = "chart"
|
||||||
prefix = "charts/"
|
prefix = "charts/"
|
||||||
schemas: Dict[str, Schema] = {
|
schemas: dict[str, Schema] = {
|
||||||
"charts/": ImportV1ChartSchema(),
|
"charts/": ImportV1ChartSchema(),
|
||||||
"datasets/": ImportV1DatasetSchema(),
|
"datasets/": ImportV1DatasetSchema(),
|
||||||
"databases/": ImportV1DatabaseSchema(),
|
"databases/": ImportV1DatabaseSchema(),
|
||||||
|
@ -49,29 +48,29 @@ class ImportChartsCommand(ImportModelsCommand):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import(
|
def _import(
|
||||||
session: Session, configs: Dict[str, Any], overwrite: bool = False
|
session: Session, configs: dict[str, Any], overwrite: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
# discover datasets associated with charts
|
# discover datasets associated with charts
|
||||||
dataset_uuids: Set[str] = set()
|
dataset_uuids: set[str] = set()
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("charts/"):
|
if file_name.startswith("charts/"):
|
||||||
dataset_uuids.add(config["dataset_uuid"])
|
dataset_uuids.add(config["dataset_uuid"])
|
||||||
|
|
||||||
# discover databases associated with datasets
|
# discover databases associated with datasets
|
||||||
database_uuids: Set[str] = set()
|
database_uuids: set[str] = set()
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
|
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
|
||||||
database_uuids.add(config["database_uuid"])
|
database_uuids.add(config["database_uuid"])
|
||||||
|
|
||||||
# import related databases
|
# import related databases
|
||||||
database_ids: Dict[str, int] = {}
|
database_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
|
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
|
||||||
database = import_database(session, config, overwrite=False)
|
database = import_database(session, config, overwrite=False)
|
||||||
database_ids[str(database.uuid)] = database.id
|
database_ids[str(database.uuid)] = database.id
|
||||||
|
|
||||||
# import datasets with the correct parent ref
|
# import datasets with the correct parent ref
|
||||||
datasets: Dict[str, SqlaTable] = {}
|
datasets: dict[str, SqlaTable] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if (
|
if (
|
||||||
file_name.startswith("datasets/")
|
file_name.startswith("datasets/")
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -28,7 +28,7 @@ from superset.models.slice import Slice
|
||||||
|
|
||||||
def import_chart(
|
def import_chart(
|
||||||
session: Session,
|
session: Session,
|
||||||
config: Dict[str, Any],
|
config: dict[str, Any],
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
ignore_permissions: bool = False,
|
ignore_permissions: bool = False,
|
||||||
) -> Slice:
|
) -> Slice:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
@ -42,14 +42,14 @@ from superset.models.slice import Slice
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_query_context_update(properties: Dict[str, Any]) -> bool:
|
def is_query_context_update(properties: dict[str, Any]) -> bool:
|
||||||
return set(properties) == {"query_context", "query_context_generation"} and bool(
|
return set(properties) == {"query_context", "query_context_generation"} and bool(
|
||||||
properties.get("query_context_generation")
|
properties.get("query_context_generation")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UpdateChartCommand(UpdateMixin, BaseCommand):
|
class UpdateChartCommand(UpdateMixin, BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[Slice] = None
|
self._model: Optional[Slice] = None
|
||||||
|
@ -67,9 +67,9 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
dashboard_ids = self._properties.get("dashboards")
|
dashboard_ids = self._properties.get("dashboards")
|
||||||
owner_ids: Optional[List[int]] = self._properties.get("owners")
|
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||||
|
|
||||||
# Validate if datasource_id is provided datasource_type is required
|
# Validate if datasource_id is provided datasource_type is required
|
||||||
datasource_id = self._properties.get("datasource_id")
|
datasource_id = self._properties.get("datasource_id")
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
# pylint: disable=arguments-renamed
|
# pylint: disable=arguments-renamed
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional, TYPE_CHECKING
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ class ChartDAO(BaseDAO):
|
||||||
base_filter = ChartFilter
|
base_filter = ChartFilter
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
|
def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None:
|
||||||
item_ids = [model.id for model in models] if models else []
|
item_ids = [model.id for model in models] if models else []
|
||||||
# bulk delete, first delete related data
|
# bulk delete, first delete related data
|
||||||
if models:
|
if models:
|
||||||
|
@ -71,7 +71,7 @@ class ChartDAO(BaseDAO):
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def favorited_ids(charts: List[Slice]) -> List[FavStar]:
|
def favorited_ids(charts: list[Slice]) -> list[FavStar]:
|
||||||
ids = [chart.id for chart in charts]
|
ids = [chart.id for chart in charts]
|
||||||
return [
|
return [
|
||||||
star.obj_id
|
star.obj_id
|
||||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
import simplejson
|
import simplejson
|
||||||
from flask import current_app, g, make_response, request, Response
|
from flask import current_app, g, make_response, request, Response
|
||||||
|
@ -315,7 +315,7 @@ class ChartDataRestApi(ChartRestApi):
|
||||||
return self._get_data_response(command, True)
|
return self._get_data_response(command, True)
|
||||||
|
|
||||||
def _run_async(
|
def _run_async(
|
||||||
self, form_data: Dict[str, Any], command: ChartDataCommand
|
self, form_data: dict[str, Any], command: ChartDataCommand
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Execute command as an async query.
|
Execute command as an async query.
|
||||||
|
@ -344,9 +344,9 @@ class ChartDataRestApi(ChartRestApi):
|
||||||
|
|
||||||
def _send_chart_response(
|
def _send_chart_response(
|
||||||
self,
|
self,
|
||||||
result: Dict[Any, Any],
|
result: dict[Any, Any],
|
||||||
form_data: Optional[Dict[str, Any]] = None,
|
form_data: dict[str, Any] | None = None,
|
||||||
datasource: Optional[Union[BaseDatasource, Query]] = None,
|
datasource: BaseDatasource | Query | None = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
result_type = result["query_context"].result_type
|
result_type = result["query_context"].result_type
|
||||||
result_format = result["query_context"].result_format
|
result_format = result["query_context"].result_format
|
||||||
|
@ -408,8 +408,8 @@ class ChartDataRestApi(ChartRestApi):
|
||||||
self,
|
self,
|
||||||
command: ChartDataCommand,
|
command: ChartDataCommand,
|
||||||
force_cached: bool = False,
|
force_cached: bool = False,
|
||||||
form_data: Optional[Dict[str, Any]] = None,
|
form_data: dict[str, Any] | None = None,
|
||||||
datasource: Optional[Union[BaseDatasource, Query]] = None,
|
datasource: BaseDatasource | Query | None = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
try:
|
try:
|
||||||
result = command.run(force_cached=force_cached)
|
result = command.run(force_cached=force_cached)
|
||||||
|
@ -421,12 +421,12 @@ class ChartDataRestApi(ChartRestApi):
|
||||||
return self._send_chart_response(result, form_data, datasource)
|
return self._send_chart_response(result, form_data, datasource)
|
||||||
|
|
||||||
# pylint: disable=invalid-name, no-self-use
|
# pylint: disable=invalid-name, no-self-use
|
||||||
def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
|
def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
|
||||||
return QueryContextCacheLoader.load(cache_key)
|
return QueryContextCacheLoader.load(cache_key)
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
def _create_query_context_from_form(
|
def _create_query_context_from_form(
|
||||||
self, form_data: Dict[str, Any]
|
self, form_data: dict[str, Any]
|
||||||
) -> QueryContext:
|
) -> QueryContext:
|
||||||
try:
|
try:
|
||||||
return ChartDataQueryContextSchema().load(form_data)
|
return ChartDataQueryContextSchema().load(form_data)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask import Request
|
from flask import Request
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
|
||||||
jwt_data = async_query_manager.parse_jwt_from_request(request)
|
jwt_data = async_query_manager.parse_jwt_from_request(request)
|
||||||
self._async_channel_id = jwt_data["channel"]
|
self._async_channel_id = jwt_data["channel"]
|
||||||
|
|
||||||
def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
|
def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
|
||||||
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
|
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
|
||||||
load_chart_data_into_cache.delay(job_metadata, form_data)
|
load_chart_data_into_cache.delay(job_metadata, form_data)
|
||||||
return job_metadata
|
return job_metadata
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand):
|
||||||
def __init__(self, query_context: QueryContext):
|
def __init__(self, query_context: QueryContext):
|
||||||
self._query_context = query_context
|
self._query_context = query_context
|
||||||
|
|
||||||
def run(self, **kwargs: Any) -> Dict[str, Any]:
|
def run(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
# caching is handled in query_context.get_df_payload
|
# caching is handled in query_context.get_df_payload
|
||||||
# (also evals `force` property)
|
# (also evals `force` property)
|
||||||
cache_query_context = kwargs.get("cache", False)
|
cache_query_context = kwargs.get("cache", False)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from superset import cache
|
from superset import cache
|
||||||
from superset.charts.commands.exceptions import ChartDataCacheLoadError
|
from superset.charts.commands.exceptions import ChartDataCacheLoadError
|
||||||
|
@ -22,7 +22,7 @@ from superset.charts.commands.exceptions import ChartDataCacheLoadError
|
||||||
|
|
||||||
class QueryContextCacheLoader: # pylint: disable=too-few-public-methods
|
class QueryContextCacheLoader: # pylint: disable=too-few-public-methods
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(cache_key: str) -> Dict[str, Any]:
|
def load(cache_key: str) -> dict[str, Any]:
|
||||||
cache_value = cache.get(cache_key)
|
cache_value = cache.get(cache_key)
|
||||||
if not cache_value:
|
if not cache_value:
|
||||||
raise ChartDataCacheLoadError("Cached data not found")
|
raise ChartDataCacheLoadError("Cached data not found")
|
||||||
|
|
|
@ -27,7 +27,7 @@ for these chart types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
@ -45,14 +45,14 @@ if TYPE_CHECKING:
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
|
||||||
|
|
||||||
def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
|
def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]:
|
||||||
"""
|
"""
|
||||||
Sort columns when combining metrics.
|
Sort columns when combining metrics.
|
||||||
|
|
||||||
MultiIndex labels have the metric name as the last element in the
|
MultiIndex labels have the metric name as the last element in the
|
||||||
tuple. We want to sort these according to the list of passed metrics.
|
tuple. We want to sort these according to the list of passed metrics.
|
||||||
"""
|
"""
|
||||||
parts: List[Any] = list(label)
|
parts: list[Any] = list(label)
|
||||||
metric = parts[-1]
|
metric = parts[-1]
|
||||||
parts[-1] = metrics.index(metric)
|
parts[-1] = metrics.index(metric)
|
||||||
return tuple(parts)
|
return tuple(parts)
|
||||||
|
@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...
|
||||||
|
|
||||||
def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
|
def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
rows: List[str],
|
rows: list[str],
|
||||||
columns: List[str],
|
columns: list[str],
|
||||||
metrics: List[str],
|
metrics: list[str],
|
||||||
aggfunc: str = "Sum",
|
aggfunc: str = "Sum",
|
||||||
transpose_pivot: bool = False,
|
transpose_pivot: bool = False,
|
||||||
combine_metrics: bool = False,
|
combine_metrics: bool = False,
|
||||||
|
@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str:
|
||||||
"""
|
"""
|
||||||
List unique values in a series.
|
List unique values in a series.
|
||||||
"""
|
"""
|
||||||
return ", ".join(set(str(v) for v in pd.Series.unique(series)))
|
return ", ".join({str(v) for v in pd.Series.unique(series)})
|
||||||
|
|
||||||
|
|
||||||
pivot_v2_aggfunc_map = {
|
pivot_v2_aggfunc_map = {
|
||||||
|
@ -223,7 +223,7 @@ pivot_v2_aggfunc_map = {
|
||||||
|
|
||||||
def pivot_table_v2(
|
def pivot_table_v2(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
form_data: Dict[str, Any],
|
form_data: dict[str, Any],
|
||||||
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
|
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
|
@ -249,7 +249,7 @@ def pivot_table_v2(
|
||||||
|
|
||||||
def pivot_table(
|
def pivot_table(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
form_data: Dict[str, Any],
|
form_data: dict[str, Any],
|
||||||
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
|
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
|
@ -285,7 +285,7 @@ def pivot_table(
|
||||||
|
|
||||||
def table(
|
def table(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
form_data: Dict[str, Any],
|
form_data: dict[str, Any],
|
||||||
datasource: Optional[ # pylint: disable=unused-argument
|
datasource: Optional[ # pylint: disable=unused-argument
|
||||||
Union["BaseDatasource", "Query"]
|
Union["BaseDatasource", "Query"]
|
||||||
] = None,
|
] = None,
|
||||||
|
@ -315,10 +315,10 @@ post_processors = {
|
||||||
|
|
||||||
|
|
||||||
def apply_post_process(
|
def apply_post_process(
|
||||||
result: Dict[Any, Any],
|
result: dict[Any, Any],
|
||||||
form_data: Optional[Dict[str, Any]] = None,
|
form_data: Optional[dict[str, Any]] = None,
|
||||||
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
|
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
|
||||||
) -> Dict[Any, Any]:
|
) -> dict[Any, Any]:
|
||||||
form_data = form_data or {}
|
form_data = form_data or {}
|
||||||
|
|
||||||
viz_type = form_data.get("viz_type")
|
viz_type = form_data.get("viz_type")
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||||
|
@ -1383,7 +1383,7 @@ class ChartDataQueryObjectSchema(Schema):
|
||||||
|
|
||||||
|
|
||||||
class ChartDataQueryContextSchema(Schema):
|
class ChartDataQueryContextSchema(Schema):
|
||||||
query_context_factory: Optional[QueryContextFactory] = None
|
query_context_factory: QueryContextFactory | None = None
|
||||||
datasource = fields.Nested(ChartDataDatasourceSchema)
|
datasource = fields.Nested(ChartDataDatasourceSchema)
|
||||||
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
|
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
|
||||||
custom_cache_timeout = fields.Integer(
|
custom_cache_timeout = fields.Integer(
|
||||||
|
@ -1407,7 +1407,7 @@ class ChartDataQueryContextSchema(Schema):
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@post_load
|
@post_load
|
||||||
def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
|
def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext:
|
||||||
query_context = self.get_query_context_factory().create(**data)
|
query_context = self.get_query_context_factory().create(**data)
|
||||||
return query_context
|
return query_context
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
from zipfile import is_zipfile, ZipFile
|
from zipfile import is_zipfile, ZipFile
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
@ -309,7 +309,7 @@ else:
|
||||||
from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand
|
from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand
|
||||||
|
|
||||||
path_object = Path(path)
|
path_object = Path(path)
|
||||||
files: List[Path] = []
|
files: list[Path] = []
|
||||||
if path_object.is_file():
|
if path_object.is_file():
|
||||||
files.append(path_object)
|
files.append(path_object)
|
||||||
elif path_object.exists() and not recursive:
|
elif path_object.exists() and not recursive:
|
||||||
|
@ -363,7 +363,7 @@ else:
|
||||||
sync_metrics = "metrics" in sync_array
|
sync_metrics = "metrics" in sync_array
|
||||||
|
|
||||||
path_object = Path(path)
|
path_object = Path(path)
|
||||||
files: List[Path] = []
|
files: list[Path] = []
|
||||||
if path_object.is_file():
|
if path_object.is_file():
|
||||||
files.append(path_object)
|
files.append(path_object)
|
||||||
elif path_object.exists() and not recursive:
|
elif path_object.exists() and not recursive:
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
|
@ -40,7 +40,7 @@ def superset() -> None:
|
||||||
"""This is a management script for the Superset application."""
|
"""This is a management script for the Superset application."""
|
||||||
|
|
||||||
@app.shell_context_processor
|
@app.shell_context_processor
|
||||||
def make_shell_context() -> Dict[str, Any]:
|
def make_shell_context() -> dict[str, Any]:
|
||||||
return dict(app=app, db=db)
|
return dict(app=app, db=db)
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,5 +79,5 @@ def version(verbose: bool) -> None:
|
||||||
)
|
)
|
||||||
print(Fore.BLUE + "-=" * 15)
|
print(Fore.BLUE + "-=" * 15)
|
||||||
if verbose:
|
if verbose:
|
||||||
print("[DB] : " + "{}".format(db.engine))
|
print("[DB] : " + f"{db.engine}")
|
||||||
print(Style.RESET_ALL)
|
print(Style.RESET_ALL)
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
import json
|
import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Set, Tuple
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
|
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
|
||||||
|
@ -102,7 +101,7 @@ def native_filters() -> None:
|
||||||
)
|
)
|
||||||
def upgrade(
|
def upgrade(
|
||||||
all_: bool, # pylint: disable=unused-argument
|
all_: bool, # pylint: disable=unused-argument
|
||||||
dashboard_ids: Tuple[int, ...],
|
dashboard_ids: tuple[int, ...],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upgrade legacy filter-box charts to native dashboard filters.
|
Upgrade legacy filter-box charts to native dashboard filters.
|
||||||
|
@ -251,7 +250,7 @@ def upgrade(
|
||||||
)
|
)
|
||||||
def downgrade(
|
def downgrade(
|
||||||
all_: bool, # pylint: disable=unused-argument
|
all_: bool, # pylint: disable=unused-argument
|
||||||
dashboard_ids: Tuple[int, ...],
|
dashboard_ids: tuple[int, ...],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Downgrade native dashboard filters to legacy filter-box charts (where applicable).
|
Downgrade native dashboard filters to legacy filter-box charts (where applicable).
|
||||||
|
@ -347,7 +346,7 @@ def downgrade(
|
||||||
)
|
)
|
||||||
def cleanup(
|
def cleanup(
|
||||||
all_: bool, # pylint: disable=unused-argument
|
all_: bool, # pylint: disable=unused-argument
|
||||||
dashboard_ids: Tuple[int, ...],
|
dashboard_ids: tuple[int, ...],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Cleanup obsolete legacy filter-box charts and interim metadata.
|
Cleanup obsolete legacy filter-box charts and interim metadata.
|
||||||
|
@ -355,7 +354,7 @@ def cleanup(
|
||||||
Note this operation is irreversible.
|
Note this operation is irreversible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
slice_ids: Set[int] = set()
|
slice_ids: set[int] = set()
|
||||||
|
|
||||||
# Cleanup the dashboard which contains legacy fields used for downgrading.
|
# Cleanup the dashboard which contains legacy fields used for downgrading.
|
||||||
for dashboard in (
|
for dashboard in (
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Type, Union
|
from typing import Union
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from celery.utils.abstract import CallableTask
|
from celery.utils.abstract import CallableTask
|
||||||
|
@ -75,7 +75,7 @@ def compute_thumbnails(
|
||||||
|
|
||||||
def compute_generic_thumbnail(
|
def compute_generic_thumbnail(
|
||||||
friendly_type: str,
|
friendly_type: str,
|
||||||
model_cls: Union[Type[Dashboard], Type[Slice]],
|
model_cls: Union[type[Dashboard], type[Slice]],
|
||||||
model_id: int,
|
model_id: int,
|
||||||
compute_func: CallableTask,
|
compute_func: CallableTask,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ class BaseCommand(ABC):
|
||||||
|
|
||||||
class CreateMixin: # pylint: disable=too-few-public-methods
|
class CreateMixin: # pylint: disable=too-few-public-methods
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
|
def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
|
||||||
"""
|
"""
|
||||||
Populate list of owners, defaulting to the current user if `owner_ids` is
|
Populate list of owners, defaulting to the current user if `owner_ids` is
|
||||||
undefined or empty. If current user is missing in `owner_ids`, current user
|
undefined or empty. If current user is missing in `owner_ids`, current user
|
||||||
|
@ -60,7 +60,7 @@ class CreateMixin: # pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
class UpdateMixin: # pylint: disable=too-few-public-methods
|
class UpdateMixin: # pylint: disable=too-few-public-methods
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
|
def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
|
||||||
"""
|
"""
|
||||||
Populate list of owners. If current user is missing in `owner_ids`, current user
|
Populate list of owners. If current user is missing in `owner_ids`, current user
|
||||||
is added unless belonging to the Admin role.
|
is added unless belonging to the Admin role.
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -59,7 +59,7 @@ class CommandInvalidError(CommandException):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: str = "",
|
message: str = "",
|
||||||
exceptions: Optional[List[ValidationError]] = None,
|
exceptions: Optional[list[ValidationError]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._exceptions = exceptions or []
|
self._exceptions = exceptions or []
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
@ -67,14 +67,14 @@ class CommandInvalidError(CommandException):
|
||||||
def append(self, exception: ValidationError) -> None:
|
def append(self, exception: ValidationError) -> None:
|
||||||
self._exceptions.append(exception)
|
self._exceptions.append(exception)
|
||||||
|
|
||||||
def extend(self, exceptions: List[ValidationError]) -> None:
|
def extend(self, exceptions: list[ValidationError]) -> None:
|
||||||
self._exceptions.extend(exceptions)
|
self._exceptions.extend(exceptions)
|
||||||
|
|
||||||
def get_list_classnames(self) -> List[str]:
|
def get_list_classnames(self) -> list[str]:
|
||||||
return list(sorted({ex.__class__.__name__ for ex in self._exceptions}))
|
return list(sorted({ex.__class__.__name__ for ex in self._exceptions}))
|
||||||
|
|
||||||
def normalized_messages(self) -> Dict[Any, Any]:
|
def normalized_messages(self) -> dict[Any, Any]:
|
||||||
errors: Dict[Any, Any] = {}
|
errors: dict[Any, Any] = {}
|
||||||
for exception in self._exceptions:
|
for exception in self._exceptions:
|
||||||
errors.update(exception.normalized_messages())
|
errors.update(exception.normalized_messages())
|
||||||
return errors
|
return errors
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Iterator, Tuple
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand):
|
||||||
Command that exports all databases, datasets, charts, dashboards and saved queries.
|
Command that exports all databases, datasets, charts, dashboards and saved queries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def run(self) -> Iterator[Tuple[str, str]]:
|
def run(self) -> Iterator[tuple[str, str]]:
|
||||||
metadata = {
|
metadata = {
|
||||||
"version": EXPORT_VERSION,
|
"version": EXPORT_VERSION,
|
||||||
"type": "assets",
|
"type": "assets",
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Iterator, List, Tuple, Type
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from flask_appbuilder import Model
|
from flask_appbuilder import Model
|
||||||
|
@ -30,21 +30,21 @@ METADATA_FILE_NAME = "metadata.yaml"
|
||||||
|
|
||||||
|
|
||||||
class ExportModelsCommand(BaseCommand):
|
class ExportModelsCommand(BaseCommand):
|
||||||
dao: Type[BaseDAO] = BaseDAO
|
dao: type[BaseDAO] = BaseDAO
|
||||||
not_found: Type[CommandException] = CommandException
|
not_found: type[CommandException] = CommandException
|
||||||
|
|
||||||
def __init__(self, model_ids: List[int], export_related: bool = True):
|
def __init__(self, model_ids: list[int], export_related: bool = True):
|
||||||
self.model_ids = model_ids
|
self.model_ids = model_ids
|
||||||
self.export_related = export_related
|
self.export_related = export_related
|
||||||
|
|
||||||
# this will be set when calling validate()
|
# this will be set when calling validate()
|
||||||
self._models: List[Model] = []
|
self._models: list[Model] = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]:
|
def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]:
|
||||||
raise NotImplementedError("Subclasses MUST implement _export")
|
raise NotImplementedError("Subclasses MUST implement _export")
|
||||||
|
|
||||||
def run(self) -> Iterator[Tuple[str, str]]:
|
def run(self) -> Iterator[tuple[str, str]]:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Optional
|
||||||
|
|
||||||
from marshmallow import Schema, validate
|
from marshmallow import Schema, validate
|
||||||
from marshmallow.exceptions import ValidationError
|
from marshmallow.exceptions import ValidationError
|
||||||
|
@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand):
|
||||||
dao = BaseDAO
|
dao = BaseDAO
|
||||||
model_name = "model"
|
model_name = "model"
|
||||||
prefix = ""
|
prefix = ""
|
||||||
schemas: Dict[str, Schema] = {}
|
schemas: dict[str, Schema] = {}
|
||||||
import_error = CommandException
|
import_error = CommandException
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
|
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
|
self.passwords: dict[str, str] = kwargs.get("passwords") or {}
|
||||||
self.ssh_tunnel_passwords: Dict[str, str] = (
|
self.ssh_tunnel_passwords: dict[str, str] = (
|
||||||
kwargs.get("ssh_tunnel_passwords") or {}
|
kwargs.get("ssh_tunnel_passwords") or {}
|
||||||
)
|
)
|
||||||
self.ssh_tunnel_private_keys: Dict[str, str] = (
|
self.ssh_tunnel_private_keys: dict[str, str] = (
|
||||||
kwargs.get("ssh_tunnel_private_keys") or {}
|
kwargs.get("ssh_tunnel_private_keys") or {}
|
||||||
)
|
)
|
||||||
self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
|
self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
|
||||||
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
|
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
|
||||||
)
|
)
|
||||||
self.overwrite: bool = kwargs.get("overwrite", False)
|
self.overwrite: bool = kwargs.get("overwrite", False)
|
||||||
self._configs: Dict[str, Any] = {}
|
self._configs: dict[str, Any] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import(
|
def _import(
|
||||||
session: Session, configs: Dict[str, Any], overwrite: bool = False
|
session: Session, configs: dict[str, Any], overwrite: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError("Subclasses MUST implement _import")
|
raise NotImplementedError("Subclasses MUST implement _import")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_uuids(cls) -> Set[str]:
|
def _get_uuids(cls) -> set[str]:
|
||||||
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
|
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
|
@ -84,11 +84,11 @@ class ImportModelsCommand(BaseCommand):
|
||||||
raise self.import_error() from ex
|
raise self.import_error() from ex
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
|
|
||||||
# verify that the metadata file is present and valid
|
# verify that the metadata file is present and valid
|
||||||
try:
|
try:
|
||||||
metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
|
metadata: Optional[dict[str, str]] = load_metadata(self.contents)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
exceptions.append(exc)
|
exceptions.append(exc)
|
||||||
metadata = None
|
metadata = None
|
||||||
|
@ -114,7 +114,7 @@ class ImportModelsCommand(BaseCommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
|
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
|
||||||
self, exceptions: List[ValidationError]
|
self, exceptions: list[ValidationError]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""check if the object exists and shouldn't be overwritten"""
|
"""check if the object exists and shouldn't be overwritten"""
|
||||||
if not self.overwrite:
|
if not self.overwrite:
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from marshmallow import Schema
|
from marshmallow import Schema
|
||||||
from marshmallow.exceptions import ValidationError
|
from marshmallow.exceptions import ValidationError
|
||||||
|
@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
and will overwrite everything.
|
and will overwrite everything.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
schemas: Dict[str, Schema] = {
|
schemas: dict[str, Schema] = {
|
||||||
"charts/": ImportV1ChartSchema(),
|
"charts/": ImportV1ChartSchema(),
|
||||||
"dashboards/": ImportV1DashboardSchema(),
|
"dashboards/": ImportV1DashboardSchema(),
|
||||||
"datasets/": ImportV1DatasetSchema(),
|
"datasets/": ImportV1DatasetSchema(),
|
||||||
|
@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
}
|
}
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
|
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
|
self.passwords: dict[str, str] = kwargs.get("passwords") or {}
|
||||||
self.ssh_tunnel_passwords: Dict[str, str] = (
|
self.ssh_tunnel_passwords: dict[str, str] = (
|
||||||
kwargs.get("ssh_tunnel_passwords") or {}
|
kwargs.get("ssh_tunnel_passwords") or {}
|
||||||
)
|
)
|
||||||
self.ssh_tunnel_private_keys: Dict[str, str] = (
|
self.ssh_tunnel_private_keys: dict[str, str] = (
|
||||||
kwargs.get("ssh_tunnel_private_keys") or {}
|
kwargs.get("ssh_tunnel_private_keys") or {}
|
||||||
)
|
)
|
||||||
self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
|
self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
|
||||||
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
|
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
|
||||||
)
|
)
|
||||||
self._configs: Dict[str, Any] = {}
|
self._configs: dict[str, Any] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import(session: Session, configs: Dict[str, Any]) -> None:
|
def _import(session: Session, configs: dict[str, Any]) -> None:
|
||||||
# import databases first
|
# import databases first
|
||||||
database_ids: Dict[str, int] = {}
|
database_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("databases/"):
|
if file_name.startswith("databases/"):
|
||||||
database = import_database(session, config, overwrite=True)
|
database = import_database(session, config, overwrite=True)
|
||||||
|
@ -95,7 +95,7 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
import_saved_query(session, config, overwrite=True)
|
import_saved_query(session, config, overwrite=True)
|
||||||
|
|
||||||
# import datasets
|
# import datasets
|
||||||
dataset_info: Dict[str, Dict[str, Any]] = {}
|
dataset_info: dict[str, dict[str, Any]] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("datasets/"):
|
if file_name.startswith("datasets/"):
|
||||||
config["database_id"] = database_ids[config["database_uuid"]]
|
config["database_id"] = database_ids[config["database_uuid"]]
|
||||||
|
@ -107,7 +107,7 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
}
|
}
|
||||||
|
|
||||||
# import charts
|
# import charts
|
||||||
chart_ids: Dict[str, int] = {}
|
chart_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("charts/"):
|
if file_name.startswith("charts/"):
|
||||||
config.update(dataset_info[config["dataset_uuid"]])
|
config.update(dataset_info[config["dataset_uuid"]])
|
||||||
|
@ -121,7 +121,7 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
dashboard = import_dashboard(session, config, overwrite=True)
|
dashboard = import_dashboard(session, config, overwrite=True)
|
||||||
|
|
||||||
# set ref in the dashboard_slices table
|
# set ref in the dashboard_slices table
|
||||||
dashboard_chart_ids: List[Dict[str, int]] = []
|
dashboard_chart_ids: list[dict[str, int]] = []
|
||||||
for uuid in find_chart_uuids(config["position"]):
|
for uuid in find_chart_uuids(config["position"]):
|
||||||
if uuid not in chart_ids:
|
if uuid not in chart_ids:
|
||||||
break
|
break
|
||||||
|
@ -151,11 +151,11 @@ class ImportAssetsCommand(BaseCommand):
|
||||||
raise ImportFailedError() from ex
|
raise ImportFailedError() from ex
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
|
|
||||||
# verify that the metadata file is present and valid
|
# verify that the metadata file is present and valid
|
||||||
try:
|
try:
|
||||||
metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
|
metadata: Optional[dict[str, str]] = load_metadata(self.contents)
|
||||||
except ValidationError as exc:
|
except ValidationError as exc:
|
||||||
exceptions.append(exc)
|
exceptions.append(exc)
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict, List, Set, Tuple
|
from typing import Any
|
||||||
|
|
||||||
from marshmallow import Schema
|
from marshmallow import Schema
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
|
|
||||||
dao = BaseDAO
|
dao = BaseDAO
|
||||||
model_name = "model"
|
model_name = "model"
|
||||||
schemas: Dict[str, Schema] = {
|
schemas: dict[str, Schema] = {
|
||||||
"charts/": ImportV1ChartSchema(),
|
"charts/": ImportV1ChartSchema(),
|
||||||
"dashboards/": ImportV1DashboardSchema(),
|
"dashboards/": ImportV1DashboardSchema(),
|
||||||
"datasets/": ImportV1DatasetSchema(),
|
"datasets/": ImportV1DatasetSchema(),
|
||||||
|
@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
}
|
}
|
||||||
import_error = CommandException
|
import_error = CommandException
|
||||||
|
|
||||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
|
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||||
super().__init__(contents, *args, **kwargs)
|
super().__init__(contents, *args, **kwargs)
|
||||||
self.force_data = kwargs.get("force_data", False)
|
self.force_data = kwargs.get("force_data", False)
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
raise self.import_error() from ex
|
raise self.import_error() from ex
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_uuids(cls) -> Set[str]:
|
def _get_uuids(cls) -> set[str]:
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return (
|
return (
|
||||||
ImportDatabasesCommand._get_uuids()
|
ImportDatabasesCommand._get_uuids()
|
||||||
|
@ -93,12 +93,12 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches
|
def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches
|
||||||
session: Session,
|
session: Session,
|
||||||
configs: Dict[str, Any],
|
configs: dict[str, Any],
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
force_data: bool = False,
|
force_data: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# import databases
|
# import databases
|
||||||
database_ids: Dict[str, int] = {}
|
database_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("databases/"):
|
if file_name.startswith("databases/"):
|
||||||
database = import_database(
|
database = import_database(
|
||||||
|
@ -114,7 +114,7 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
# database was created before its UUID was frozen, so it has a random UUID.
|
# database was created before its UUID was frozen, so it has a random UUID.
|
||||||
# We need to determine its ID so we can point the dataset to it.
|
# We need to determine its ID so we can point the dataset to it.
|
||||||
examples_db = get_example_database()
|
examples_db = get_example_database()
|
||||||
dataset_info: Dict[str, Dict[str, Any]] = {}
|
dataset_info: dict[str, dict[str, Any]] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("datasets/"):
|
if file_name.startswith("datasets/"):
|
||||||
# find the ID of the corresponding database
|
# find the ID of the corresponding database
|
||||||
|
@ -153,7 +153,7 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
}
|
}
|
||||||
|
|
||||||
# import charts
|
# import charts
|
||||||
chart_ids: Dict[str, int] = {}
|
chart_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if (
|
if (
|
||||||
file_name.startswith("charts/")
|
file_name.startswith("charts/")
|
||||||
|
@ -175,7 +175,7 @@ class ImportExamplesCommand(ImportModelsCommand):
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
# import dashboards
|
# import dashboards
|
||||||
dashboard_chart_ids: List[Tuple[int, int]] = []
|
dashboard_chart_ids: list[tuple[int, int]] = []
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("dashboards/"):
|
if file_name.startswith("dashboards/"):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path, PurePosixPath
|
from pathlib import Path, PurePosixPath
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -46,7 +46,7 @@ class MetadataSchema(Schema):
|
||||||
timestamp = fields.DateTime()
|
timestamp = fields.DateTime()
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
|
def load_yaml(file_name: str, content: str) -> dict[str, Any]:
|
||||||
"""Try to load a YAML file"""
|
"""Try to load a YAML file"""
|
||||||
try:
|
try:
|
||||||
return yaml.safe_load(content)
|
return yaml.safe_load(content)
|
||||||
|
@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
|
||||||
raise ValidationError({file_name: "Not a valid YAML file"}) from ex
|
raise ValidationError({file_name: "Not a valid YAML file"}) from ex
|
||||||
|
|
||||||
|
|
||||||
def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
|
def load_metadata(contents: dict[str, str]) -> dict[str, str]:
|
||||||
"""Apply validation and load a metadata file"""
|
"""Apply validation and load a metadata file"""
|
||||||
if METADATA_FILE_NAME not in contents:
|
if METADATA_FILE_NAME not in contents:
|
||||||
# if the contents have no METADATA_FILE_NAME this is probably
|
# if the contents have no METADATA_FILE_NAME this is probably
|
||||||
|
@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
|
||||||
|
|
||||||
def validate_metadata_type(
|
def validate_metadata_type(
|
||||||
metadata: Optional[Dict[str, str]],
|
metadata: Optional[dict[str, str]],
|
||||||
type_: str,
|
type_: str,
|
||||||
exceptions: List[ValidationError],
|
exceptions: list[ValidationError],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
|
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
|
||||||
if metadata and "type" in metadata:
|
if metadata and "type" in metadata:
|
||||||
|
@ -96,35 +96,35 @@ def validate_metadata_type(
|
||||||
|
|
||||||
# pylint: disable=too-many-locals,too-many-arguments
|
# pylint: disable=too-many-locals,too-many-arguments
|
||||||
def load_configs(
|
def load_configs(
|
||||||
contents: Dict[str, str],
|
contents: dict[str, str],
|
||||||
schemas: Dict[str, Schema],
|
schemas: dict[str, Schema],
|
||||||
passwords: Dict[str, str],
|
passwords: dict[str, str],
|
||||||
exceptions: List[ValidationError],
|
exceptions: list[ValidationError],
|
||||||
ssh_tunnel_passwords: Dict[str, str],
|
ssh_tunnel_passwords: dict[str, str],
|
||||||
ssh_tunnel_private_keys: Dict[str, str],
|
ssh_tunnel_private_keys: dict[str, str],
|
||||||
ssh_tunnel_priv_key_passwords: Dict[str, str],
|
ssh_tunnel_priv_key_passwords: dict[str, str],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
configs: Dict[str, Any] = {}
|
configs: dict[str, Any] = {}
|
||||||
|
|
||||||
# load existing databases so we can apply the password validation
|
# load existing databases so we can apply the password validation
|
||||||
db_passwords: Dict[str, str] = {
|
db_passwords: dict[str, str] = {
|
||||||
str(uuid): password
|
str(uuid): password
|
||||||
for uuid, password in db.session.query(Database.uuid, Database.password).all()
|
for uuid, password in db.session.query(Database.uuid, Database.password).all()
|
||||||
}
|
}
|
||||||
# load existing ssh_tunnels so we can apply the password validation
|
# load existing ssh_tunnels so we can apply the password validation
|
||||||
db_ssh_tunnel_passwords: Dict[str, str] = {
|
db_ssh_tunnel_passwords: dict[str, str] = {
|
||||||
str(uuid): password
|
str(uuid): password
|
||||||
for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
|
for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
|
||||||
}
|
}
|
||||||
# load existing ssh_tunnels so we can apply the private_key validation
|
# load existing ssh_tunnels so we can apply the private_key validation
|
||||||
db_ssh_tunnel_private_keys: Dict[str, str] = {
|
db_ssh_tunnel_private_keys: dict[str, str] = {
|
||||||
str(uuid): private_key
|
str(uuid): private_key
|
||||||
for uuid, private_key in db.session.query(
|
for uuid, private_key in db.session.query(
|
||||||
SSHTunnel.uuid, SSHTunnel.private_key
|
SSHTunnel.uuid, SSHTunnel.private_key
|
||||||
).all()
|
).all()
|
||||||
}
|
}
|
||||||
# load existing ssh_tunnels so we can apply the private_key_password validation
|
# load existing ssh_tunnels so we can apply the private_key_password validation
|
||||||
db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
|
db_ssh_tunnel_priv_key_passws: dict[str, str] = {
|
||||||
str(uuid): private_key_password
|
str(uuid): private_key_password
|
||||||
for uuid, private_key_password in db.session.query(
|
for uuid, private_key_password in db.session.query(
|
||||||
SSHTunnel.uuid, SSHTunnel.private_key_password
|
SSHTunnel.uuid, SSHTunnel.private_key_password
|
||||||
|
@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]:
|
def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
remove_root(file_name): bundle.read(file_name).decode()
|
remove_root(file_name): bundle.read(file_name).decode()
|
||||||
for file_name in bundle.namelist()
|
for file_name in bundle.namelist()
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List, Optional, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_appbuilder.security.sqla.models import Role, User
|
from flask_appbuilder.security.sqla.models import Role, User
|
||||||
|
@ -37,9 +37,9 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
def populate_owners(
|
def populate_owners(
|
||||||
owner_ids: Optional[List[int]],
|
owner_ids: list[int] | None,
|
||||||
default_to_user: bool,
|
default_to_user: bool,
|
||||||
) -> List[User]:
|
) -> list[User]:
|
||||||
"""
|
"""
|
||||||
Helper function for commands, will fetch all users from owners id's
|
Helper function for commands, will fetch all users from owners id's
|
||||||
|
|
||||||
|
@ -63,13 +63,13 @@ def populate_owners(
|
||||||
return owners
|
return owners
|
||||||
|
|
||||||
|
|
||||||
def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
|
def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
|
||||||
"""
|
"""
|
||||||
Helper function for commands, will fetch all roles from roles id's
|
Helper function for commands, will fetch all roles from roles id's
|
||||||
:raises RolesNotFoundValidationError: If a role in the input list is not found
|
:raises RolesNotFoundValidationError: If a role in the input list is not found
|
||||||
:param role_ids: A List of roles by id's
|
:param role_ids: A List of roles by id's
|
||||||
"""
|
"""
|
||||||
roles: List[Role] = []
|
roles: list[Role] = []
|
||||||
if role_ids:
|
if role_ids:
|
||||||
roles = security_manager.find_roles_by_id(role_ids)
|
roles = security_manager.find_roles_by_id(role_ids)
|
||||||
if len(roles) != len(role_ids):
|
if len(roles) != len(role_ids):
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Set
|
|
||||||
|
|
||||||
|
|
||||||
class ChartDataResultFormat(str, Enum):
|
class ChartDataResultFormat(str, Enum):
|
||||||
|
@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum):
|
||||||
XLSX = "xlsx"
|
XLSX = "xlsx"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def table_like(cls) -> Set["ChartDataResultFormat"]:
|
def table_like(cls) -> set["ChartDataResultFormat"]:
|
||||||
return {cls.CSV} | {cls.XLSX}
|
return {cls.CSV} | {cls.XLSX}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
|
from typing import Any, Callable, TYPE_CHECKING
|
||||||
|
|
||||||
from flask_babel import _
|
from flask_babel import _
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ def _get_datasource(
|
||||||
|
|
||||||
def _get_columns(
|
def _get_columns(
|
||||||
query_context: QueryContext, query_obj: QueryObject, _: bool
|
query_context: QueryContext, query_obj: QueryObject, _: bool
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
datasource = _get_datasource(query_context, query_obj)
|
datasource = _get_datasource(query_context, query_obj)
|
||||||
return {
|
return {
|
||||||
"data": [
|
"data": [
|
||||||
|
@ -65,7 +65,7 @@ def _get_columns(
|
||||||
|
|
||||||
def _get_timegrains(
|
def _get_timegrains(
|
||||||
query_context: QueryContext, query_obj: QueryObject, _: bool
|
query_context: QueryContext, query_obj: QueryObject, _: bool
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
datasource = _get_datasource(query_context, query_obj)
|
datasource = _get_datasource(query_context, query_obj)
|
||||||
return {
|
return {
|
||||||
"data": [
|
"data": [
|
||||||
|
@ -83,7 +83,7 @@ def _get_query(
|
||||||
query_context: QueryContext,
|
query_context: QueryContext,
|
||||||
query_obj: QueryObject,
|
query_obj: QueryObject,
|
||||||
_: bool,
|
_: bool,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
datasource = _get_datasource(query_context, query_obj)
|
datasource = _get_datasource(query_context, query_obj)
|
||||||
result = {"language": datasource.query_language}
|
result = {"language": datasource.query_language}
|
||||||
try:
|
try:
|
||||||
|
@ -96,8 +96,8 @@ def _get_query(
|
||||||
def _get_full(
|
def _get_full(
|
||||||
query_context: QueryContext,
|
query_context: QueryContext,
|
||||||
query_obj: QueryObject,
|
query_obj: QueryObject,
|
||||||
force_cached: Optional[bool] = False,
|
force_cached: bool | None = False,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
datasource = _get_datasource(query_context, query_obj)
|
datasource = _get_datasource(query_context, query_obj)
|
||||||
result_type = query_obj.result_type or query_context.result_type
|
result_type = query_obj.result_type or query_context.result_type
|
||||||
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
|
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
|
||||||
|
@ -141,7 +141,7 @@ def _get_full(
|
||||||
|
|
||||||
def _get_samples(
|
def _get_samples(
|
||||||
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
|
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
datasource = _get_datasource(query_context, query_obj)
|
datasource = _get_datasource(query_context, query_obj)
|
||||||
query_obj = copy.copy(query_obj)
|
query_obj = copy.copy(query_obj)
|
||||||
query_obj.is_timeseries = False
|
query_obj.is_timeseries = False
|
||||||
|
@ -162,7 +162,7 @@ def _get_samples(
|
||||||
|
|
||||||
def _get_drill_detail(
|
def _get_drill_detail(
|
||||||
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
|
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
# todo(yongjie): Remove this function,
|
# todo(yongjie): Remove this function,
|
||||||
# when determining whether samples should be applied to the time filter.
|
# when determining whether samples should be applied to the time filter.
|
||||||
datasource = _get_datasource(query_context, query_obj)
|
datasource = _get_datasource(query_context, query_obj)
|
||||||
|
@ -183,13 +183,13 @@ def _get_drill_detail(
|
||||||
|
|
||||||
def _get_results(
|
def _get_results(
|
||||||
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
|
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
payload = _get_full(query_context, query_obj, force_cached)
|
payload = _get_full(query_context, query_obj, force_cached)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
_result_type_functions: Dict[
|
_result_type_functions: dict[
|
||||||
ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]]
|
ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]]
|
||||||
] = {
|
] = {
|
||||||
ChartDataResultType.COLUMNS: _get_columns,
|
ChartDataResultType.COLUMNS: _get_columns,
|
||||||
ChartDataResultType.TIMEGRAINS: _get_timegrains,
|
ChartDataResultType.TIMEGRAINS: _get_timegrains,
|
||||||
|
@ -210,7 +210,7 @@ def get_query_results(
|
||||||
query_context: QueryContext,
|
query_context: QueryContext,
|
||||||
query_obj: QueryObject,
|
query_obj: QueryObject,
|
||||||
force_cached: bool,
|
force_cached: bool,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Return result payload for a chart data request.
|
Return result payload for a chart data request.
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
|
from typing import Any, ClassVar, TYPE_CHECKING
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
@ -47,15 +47,15 @@ class QueryContext:
|
||||||
enforce_numerical_metrics: ClassVar[bool] = True
|
enforce_numerical_metrics: ClassVar[bool] = True
|
||||||
|
|
||||||
datasource: BaseDatasource
|
datasource: BaseDatasource
|
||||||
slice_: Optional[Slice] = None
|
slice_: Slice | None = None
|
||||||
queries: List[QueryObject]
|
queries: list[QueryObject]
|
||||||
form_data: Optional[Dict[str, Any]]
|
form_data: dict[str, Any] | None
|
||||||
result_type: ChartDataResultType
|
result_type: ChartDataResultType
|
||||||
result_format: ChartDataResultFormat
|
result_format: ChartDataResultFormat
|
||||||
force: bool
|
force: bool
|
||||||
custom_cache_timeout: Optional[int]
|
custom_cache_timeout: int | None
|
||||||
|
|
||||||
cache_values: Dict[str, Any]
|
cache_values: dict[str, Any]
|
||||||
|
|
||||||
_processor: QueryContextProcessor
|
_processor: QueryContextProcessor
|
||||||
|
|
||||||
|
@ -65,14 +65,14 @@ class QueryContext:
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
datasource: BaseDatasource,
|
datasource: BaseDatasource,
|
||||||
queries: List[QueryObject],
|
queries: list[QueryObject],
|
||||||
slice_: Optional[Slice],
|
slice_: Slice | None,
|
||||||
form_data: Optional[Dict[str, Any]],
|
form_data: dict[str, Any] | None,
|
||||||
result_type: ChartDataResultType,
|
result_type: ChartDataResultType,
|
||||||
result_format: ChartDataResultFormat,
|
result_format: ChartDataResultFormat,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
custom_cache_timeout: Optional[int] = None,
|
custom_cache_timeout: int | None = None,
|
||||||
cache_values: Dict[str, Any],
|
cache_values: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.datasource = datasource
|
self.datasource = datasource
|
||||||
self.slice_ = slice_
|
self.slice_ = slice_
|
||||||
|
@ -88,18 +88,18 @@ class QueryContext:
|
||||||
def get_data(
|
def get_data(
|
||||||
self,
|
self,
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
) -> Union[str, List[Dict[str, Any]]]:
|
) -> str | list[dict[str, Any]]:
|
||||||
return self._processor.get_data(df)
|
return self._processor.get_data(df)
|
||||||
|
|
||||||
def get_payload(
|
def get_payload(
|
||||||
self,
|
self,
|
||||||
cache_query_context: Optional[bool] = False,
|
cache_query_context: bool | None = False,
|
||||||
force_cached: bool = False,
|
force_cached: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Returns the query results with both metadata and data"""
|
"""Returns the query results with both metadata and data"""
|
||||||
return self._processor.get_payload(cache_query_context, force_cached)
|
return self._processor.get_payload(cache_query_context, force_cached)
|
||||||
|
|
||||||
def get_cache_timeout(self) -> Optional[int]:
|
def get_cache_timeout(self) -> int | None:
|
||||||
if self.custom_cache_timeout is not None:
|
if self.custom_cache_timeout is not None:
|
||||||
return self.custom_cache_timeout
|
return self.custom_cache_timeout
|
||||||
if self.slice_ and self.slice_.cache_timeout is not None:
|
if self.slice_ and self.slice_.cache_timeout is not None:
|
||||||
|
@ -110,14 +110,14 @@ class QueryContext:
|
||||||
return self.datasource.database.cache_timeout
|
return self.datasource.database.cache_timeout
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
|
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
|
||||||
return self._processor.query_cache_key(query_obj, **kwargs)
|
return self._processor.query_cache_key(query_obj, **kwargs)
|
||||||
|
|
||||||
def get_df_payload(
|
def get_df_payload(
|
||||||
self,
|
self,
|
||||||
query_obj: QueryObject,
|
query_obj: QueryObject,
|
||||||
force_cached: Optional[bool] = False,
|
force_cached: bool | None = False,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return self._processor.get_df_payload(
|
return self._processor.get_df_payload(
|
||||||
query_obj=query_obj,
|
query_obj=query_obj,
|
||||||
force_cached=force_cached,
|
force_cached=force_cached,
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from superset import app, db
|
from superset import app, db
|
||||||
from superset.charts.dao import ChartDAO
|
from superset.charts.dao import ChartDAO
|
||||||
|
@ -48,12 +48,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
datasource: DatasourceDict,
|
datasource: DatasourceDict,
|
||||||
queries: List[Dict[str, Any]],
|
queries: list[dict[str, Any]],
|
||||||
form_data: Optional[Dict[str, Any]] = None,
|
form_data: dict[str, Any] | None = None,
|
||||||
result_type: Optional[ChartDataResultType] = None,
|
result_type: ChartDataResultType | None = None,
|
||||||
result_format: Optional[ChartDataResultFormat] = None,
|
result_format: ChartDataResultFormat | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
custom_cache_timeout: Optional[int] = None,
|
custom_cache_timeout: int | None = None,
|
||||||
) -> QueryContext:
|
) -> QueryContext:
|
||||||
datasource_model_instance = None
|
datasource_model_instance = None
|
||||||
if datasource:
|
if datasource:
|
||||||
|
@ -101,13 +101,13 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
|
||||||
datasource_id=int(datasource["id"]),
|
datasource_id=int(datasource["id"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_slice(self, slice_id: Any) -> Optional[Slice]:
|
def _get_slice(self, slice_id: Any) -> Slice | None:
|
||||||
return ChartDAO.find_by_id(slice_id)
|
return ChartDAO.find_by_id(slice_id)
|
||||||
|
|
||||||
def _process_query_object(
|
def _process_query_object(
|
||||||
self,
|
self,
|
||||||
datasource: BaseDatasource,
|
datasource: BaseDatasource,
|
||||||
form_data: Optional[Dict[str, Any]],
|
form_data: dict[str, Any] | None,
|
||||||
query_object: QueryObject,
|
query_object: QueryObject,
|
||||||
) -> QueryObject:
|
) -> QueryObject:
|
||||||
self._apply_granularity(query_object, form_data, datasource)
|
self._apply_granularity(query_object, form_data, datasource)
|
||||||
|
@ -117,7 +117,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
|
||||||
def _apply_granularity(
|
def _apply_granularity(
|
||||||
self,
|
self,
|
||||||
query_object: QueryObject,
|
query_object: QueryObject,
|
||||||
form_data: Optional[Dict[str, Any]],
|
form_data: dict[str, Any] | None,
|
||||||
datasource: BaseDatasource,
|
datasource: BaseDatasource,
|
||||||
) -> None:
|
) -> None:
|
||||||
temporal_columns = {
|
temporal_columns = {
|
||||||
|
|
|
@ -19,7 +19,7 @@ from __future__ import annotations
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
|
from typing import Any, ClassVar, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -77,8 +77,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class CachedTimeOffset(TypedDict):
|
class CachedTimeOffset(TypedDict):
|
||||||
df: pd.DataFrame
|
df: pd.DataFrame
|
||||||
queries: List[str]
|
queries: list[str]
|
||||||
cache_keys: List[Optional[str]]
|
cache_keys: list[str | None]
|
||||||
|
|
||||||
|
|
||||||
class QueryContextProcessor:
|
class QueryContextProcessor:
|
||||||
|
@ -102,8 +102,8 @@ class QueryContextProcessor:
|
||||||
enforce_numerical_metrics: ClassVar[bool] = True
|
enforce_numerical_metrics: ClassVar[bool] = True
|
||||||
|
|
||||||
def get_df_payload(
|
def get_df_payload(
|
||||||
self, query_obj: QueryObject, force_cached: Optional[bool] = False
|
self, query_obj: QueryObject, force_cached: bool | None = False
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handles caching around the df payload retrieval"""
|
"""Handles caching around the df payload retrieval"""
|
||||||
cache_key = self.query_cache_key(query_obj)
|
cache_key = self.query_cache_key(query_obj)
|
||||||
timeout = self.get_cache_timeout()
|
timeout = self.get_cache_timeout()
|
||||||
|
@ -181,7 +181,7 @@ class QueryContextProcessor:
|
||||||
"label_map": label_map,
|
"label_map": label_map,
|
||||||
}
|
}
|
||||||
|
|
||||||
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
|
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
|
||||||
"""
|
"""
|
||||||
Returns a QueryObject cache key for objects in self.queries
|
Returns a QueryObject cache key for objects in self.queries
|
||||||
"""
|
"""
|
||||||
|
@ -248,8 +248,8 @@ class QueryContextProcessor:
|
||||||
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
|
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
|
||||||
# todo: should support "python_date_format" and "get_column" in each datasource
|
# todo: should support "python_date_format" and "get_column" in each datasource
|
||||||
def _get_timestamp_format(
|
def _get_timestamp_format(
|
||||||
source: BaseDatasource, column: Optional[str]
|
source: BaseDatasource, column: str | None
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
column_obj = source.get_column(column)
|
column_obj = source.get_column(column)
|
||||||
if (
|
if (
|
||||||
column_obj
|
column_obj
|
||||||
|
@ -315,9 +315,9 @@ class QueryContextProcessor:
|
||||||
query_context = self._query_context
|
query_context = self._query_context
|
||||||
# ensure query_object is immutable
|
# ensure query_object is immutable
|
||||||
query_object_clone = copy.copy(query_object)
|
query_object_clone = copy.copy(query_object)
|
||||||
queries: List[str] = []
|
queries: list[str] = []
|
||||||
cache_keys: List[Optional[str]] = []
|
cache_keys: list[str | None] = []
|
||||||
rv_dfs: List[pd.DataFrame] = [df]
|
rv_dfs: list[pd.DataFrame] = [df]
|
||||||
|
|
||||||
time_offsets = query_object.time_offsets
|
time_offsets = query_object.time_offsets
|
||||||
outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
|
outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
|
||||||
|
@ -449,7 +449,7 @@ class QueryContextProcessor:
|
||||||
rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
|
rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
|
||||||
return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
|
return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
|
||||||
|
|
||||||
def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
|
def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
|
||||||
if self._query_context.result_format in ChartDataResultFormat.table_like():
|
if self._query_context.result_format in ChartDataResultFormat.table_like():
|
||||||
include_index = not isinstance(df.index, pd.RangeIndex)
|
include_index = not isinstance(df.index, pd.RangeIndex)
|
||||||
columns = list(df.columns)
|
columns = list(df.columns)
|
||||||
|
@ -470,9 +470,9 @@ class QueryContextProcessor:
|
||||||
|
|
||||||
def get_payload(
|
def get_payload(
|
||||||
self,
|
self,
|
||||||
cache_query_context: Optional[bool] = False,
|
cache_query_context: bool | None = False,
|
||||||
force_cached: bool = False,
|
force_cached: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Returns the query results with both metadata and data"""
|
"""Returns the query results with both metadata and data"""
|
||||||
|
|
||||||
# Get all the payloads from the QueryObjects
|
# Get all the payloads from the QueryObjects
|
||||||
|
@ -522,13 +522,13 @@ class QueryContextProcessor:
|
||||||
|
|
||||||
return generate_cache_key(cache_dict, key_prefix)
|
return generate_cache_key(cache_dict, key_prefix)
|
||||||
|
|
||||||
def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
|
def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
:param query_context:
|
:param query_context:
|
||||||
:param query_obj:
|
:param query_obj:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
|
annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj)
|
||||||
for annotation_layer in [
|
for annotation_layer in [
|
||||||
layer
|
layer
|
||||||
for layer in query_obj.annotation_layers
|
for layer in query_obj.annotation_layers
|
||||||
|
@ -541,7 +541,7 @@ class QueryContextProcessor:
|
||||||
return annotation_data
|
return annotation_data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
|
def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]:
|
||||||
annotation_data = {}
|
annotation_data = {}
|
||||||
annotation_layers = [
|
annotation_layers = [
|
||||||
layer
|
layer
|
||||||
|
@ -576,8 +576,8 @@ class QueryContextProcessor:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_viz_annotation_data(
|
def get_viz_annotation_data(
|
||||||
annotation_layer: Dict[str, Any], force: bool
|
annotation_layer: dict[str, Any], force: bool
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
chart = ChartDAO.find_by_id(annotation_layer["value"])
|
chart = ChartDAO.find_by_id(annotation_layer["value"])
|
||||||
if not chart:
|
if not chart:
|
||||||
raise QueryObjectValidationError(_("The chart does not exist"))
|
raise QueryObjectValidationError(_("The chart does not exist"))
|
||||||
|
|
|
@ -21,7 +21,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
|
from typing import Any, NamedTuple, TYPE_CHECKING
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
|
@ -81,58 +81,58 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
and druid. The query objects are constructed on the client.
|
and druid. The query objects are constructed on the client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
annotation_layers: List[Dict[str, Any]]
|
annotation_layers: list[dict[str, Any]]
|
||||||
applied_time_extras: Dict[str, str]
|
applied_time_extras: dict[str, str]
|
||||||
apply_fetch_values_predicate: bool
|
apply_fetch_values_predicate: bool
|
||||||
columns: List[Column]
|
columns: list[Column]
|
||||||
datasource: Optional[BaseDatasource]
|
datasource: BaseDatasource | None
|
||||||
extras: Dict[str, Any]
|
extras: dict[str, Any]
|
||||||
filter: List[QueryObjectFilterClause]
|
filter: list[QueryObjectFilterClause]
|
||||||
from_dttm: Optional[datetime]
|
from_dttm: datetime | None
|
||||||
granularity: Optional[str]
|
granularity: str | None
|
||||||
inner_from_dttm: Optional[datetime]
|
inner_from_dttm: datetime | None
|
||||||
inner_to_dttm: Optional[datetime]
|
inner_to_dttm: datetime | None
|
||||||
is_rowcount: bool
|
is_rowcount: bool
|
||||||
is_timeseries: bool
|
is_timeseries: bool
|
||||||
metrics: Optional[List[Metric]]
|
metrics: list[Metric] | None
|
||||||
order_desc: bool
|
order_desc: bool
|
||||||
orderby: List[OrderBy]
|
orderby: list[OrderBy]
|
||||||
post_processing: List[Dict[str, Any]]
|
post_processing: list[dict[str, Any]]
|
||||||
result_type: Optional[ChartDataResultType]
|
result_type: ChartDataResultType | None
|
||||||
row_limit: Optional[int]
|
row_limit: int | None
|
||||||
row_offset: int
|
row_offset: int
|
||||||
series_columns: List[Column]
|
series_columns: list[Column]
|
||||||
series_limit: int
|
series_limit: int
|
||||||
series_limit_metric: Optional[Metric]
|
series_limit_metric: Metric | None
|
||||||
time_offsets: List[str]
|
time_offsets: list[str]
|
||||||
time_shift: Optional[str]
|
time_shift: str | None
|
||||||
time_range: Optional[str]
|
time_range: str | None
|
||||||
to_dttm: Optional[datetime]
|
to_dttm: datetime | None
|
||||||
|
|
||||||
def __init__( # pylint: disable=too-many-locals
|
def __init__( # pylint: disable=too-many-locals
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
annotation_layers: Optional[List[Dict[str, Any]]] = None,
|
annotation_layers: list[dict[str, Any]] | None = None,
|
||||||
applied_time_extras: Optional[Dict[str, str]] = None,
|
applied_time_extras: dict[str, str] | None = None,
|
||||||
apply_fetch_values_predicate: bool = False,
|
apply_fetch_values_predicate: bool = False,
|
||||||
columns: Optional[List[Column]] = None,
|
columns: list[Column] | None = None,
|
||||||
datasource: Optional[BaseDatasource] = None,
|
datasource: BaseDatasource | None = None,
|
||||||
extras: Optional[Dict[str, Any]] = None,
|
extras: dict[str, Any] | None = None,
|
||||||
filters: Optional[List[QueryObjectFilterClause]] = None,
|
filters: list[QueryObjectFilterClause] | None = None,
|
||||||
granularity: Optional[str] = None,
|
granularity: str | None = None,
|
||||||
is_rowcount: bool = False,
|
is_rowcount: bool = False,
|
||||||
is_timeseries: Optional[bool] = None,
|
is_timeseries: bool | None = None,
|
||||||
metrics: Optional[List[Metric]] = None,
|
metrics: list[Metric] | None = None,
|
||||||
order_desc: bool = True,
|
order_desc: bool = True,
|
||||||
orderby: Optional[List[OrderBy]] = None,
|
orderby: list[OrderBy] | None = None,
|
||||||
post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
|
post_processing: list[dict[str, Any] | None] | None = None,
|
||||||
row_limit: Optional[int],
|
row_limit: int | None,
|
||||||
row_offset: Optional[int] = None,
|
row_offset: int | None = None,
|
||||||
series_columns: Optional[List[Column]] = None,
|
series_columns: list[Column] | None = None,
|
||||||
series_limit: int = 0,
|
series_limit: int = 0,
|
||||||
series_limit_metric: Optional[Metric] = None,
|
series_limit_metric: Metric | None = None,
|
||||||
time_range: Optional[str] = None,
|
time_range: str | None = None,
|
||||||
time_shift: Optional[str] = None,
|
time_shift: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
self._set_annotation_layers(annotation_layers)
|
self._set_annotation_layers(annotation_layers)
|
||||||
|
@ -166,7 +166,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
self._move_deprecated_extra_fields(kwargs)
|
self._move_deprecated_extra_fields(kwargs)
|
||||||
|
|
||||||
def _set_annotation_layers(
|
def _set_annotation_layers(
|
||||||
self, annotation_layers: Optional[List[Dict[str, Any]]]
|
self, annotation_layers: list[dict[str, Any]] | None
|
||||||
) -> None:
|
) -> None:
|
||||||
self.annotation_layers = [
|
self.annotation_layers = [
|
||||||
layer
|
layer
|
||||||
|
@ -175,14 +175,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
if layer["annotationType"] != "FORMULA"
|
if layer["annotationType"] != "FORMULA"
|
||||||
]
|
]
|
||||||
|
|
||||||
def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
|
def _set_is_timeseries(self, is_timeseries: bool | None) -> None:
|
||||||
# is_timeseries is True if time column is in either columns or groupby
|
# is_timeseries is True if time column is in either columns or groupby
|
||||||
# (both are dimensions)
|
# (both are dimensions)
|
||||||
self.is_timeseries = (
|
self.is_timeseries = (
|
||||||
is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
|
is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
|
def _set_metrics(self, metrics: list[Metric] | None = None) -> None:
|
||||||
# Support metric reference/definition in the format of
|
# Support metric reference/definition in the format of
|
||||||
# 1. 'metric_name' - name of predefined metric
|
# 1. 'metric_name' - name of predefined metric
|
||||||
# 2. { label: 'label_name' } - legacy format for a predefined metric
|
# 2. { label: 'label_name' } - legacy format for a predefined metric
|
||||||
|
@ -195,16 +195,16 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
]
|
]
|
||||||
|
|
||||||
def _set_post_processing(
|
def _set_post_processing(
|
||||||
self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
|
self, post_processing: list[dict[str, Any] | None] | None
|
||||||
) -> None:
|
) -> None:
|
||||||
post_processing = post_processing or []
|
post_processing = post_processing or []
|
||||||
self.post_processing = [post_proc for post_proc in post_processing if post_proc]
|
self.post_processing = [post_proc for post_proc in post_processing if post_proc]
|
||||||
|
|
||||||
def _init_series_columns(
|
def _init_series_columns(
|
||||||
self,
|
self,
|
||||||
series_columns: Optional[List[Column]],
|
series_columns: list[Column] | None,
|
||||||
metrics: Optional[List[Metric]],
|
metrics: list[Metric] | None,
|
||||||
is_timeseries: Optional[bool],
|
is_timeseries: bool | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if series_columns:
|
if series_columns:
|
||||||
self.series_columns = series_columns
|
self.series_columns = series_columns
|
||||||
|
@ -213,7 +213,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
else:
|
else:
|
||||||
self.series_columns = []
|
self.series_columns = []
|
||||||
|
|
||||||
def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
|
def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None:
|
||||||
# rename deprecated fields
|
# rename deprecated fields
|
||||||
for field in DEPRECATED_FIELDS:
|
for field in DEPRECATED_FIELDS:
|
||||||
if field.old_name in kwargs:
|
if field.old_name in kwargs:
|
||||||
|
@ -233,7 +233,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
)
|
)
|
||||||
setattr(self, field.new_name, value)
|
setattr(self, field.new_name, value)
|
||||||
|
|
||||||
def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None:
|
def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None:
|
||||||
# move deprecated extras fields to extras
|
# move deprecated extras fields to extras
|
||||||
for field in DEPRECATED_EXTRAS_FIELDS:
|
for field in DEPRECATED_EXTRAS_FIELDS:
|
||||||
if field.old_name in kwargs:
|
if field.old_name in kwargs:
|
||||||
|
@ -256,19 +256,19 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
self.extras[field.new_name] = value
|
self.extras[field.new_name] = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric_names(self) -> List[str]:
|
def metric_names(self) -> list[str]:
|
||||||
"""Return metrics names (labels), coerce adhoc metrics to strings."""
|
"""Return metrics names (labels), coerce adhoc metrics to strings."""
|
||||||
return get_metric_names(self.metrics or [])
|
return get_metric_names(self.metrics or [])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def column_names(self) -> List[str]:
|
def column_names(self) -> list[str]:
|
||||||
"""Return column names (labels). Gives priority to groupbys if both groupbys
|
"""Return column names (labels). Gives priority to groupbys if both groupbys
|
||||||
and metrics are non-empty, otherwise returns column labels."""
|
and metrics are non-empty, otherwise returns column labels."""
|
||||||
return get_column_names(self.columns)
|
return get_column_names(self.columns)
|
||||||
|
|
||||||
def validate(
|
def validate(
|
||||||
self, raise_exceptions: Optional[bool] = True
|
self, raise_exceptions: bool | None = True
|
||||||
) -> Optional[QueryObjectValidationError]:
|
) -> QueryObjectValidationError | None:
|
||||||
"""Validate query object"""
|
"""Validate query object"""
|
||||||
try:
|
try:
|
||||||
self._validate_there_are_no_missing_series()
|
self._validate_there_are_no_missing_series()
|
||||||
|
@ -314,7 +314,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
query_object_dict = {
|
query_object_dict = {
|
||||||
"apply_fetch_values_predicate": self.apply_fetch_values_predicate,
|
"apply_fetch_values_predicate": self.apply_fetch_values_predicate,
|
||||||
"columns": self.columns,
|
"columns": self.columns,
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from superset.common.chart_data import ChartDataResultType
|
from superset.common.chart_data import ChartDataResultType
|
||||||
from superset.common.query_object import QueryObject
|
from superset.common.query_object import QueryObject
|
||||||
|
@ -31,13 +31,13 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
||||||
_config: Dict[str, Any]
|
_config: dict[str, Any]
|
||||||
_datasource_dao: DatasourceDAO
|
_datasource_dao: DatasourceDAO
|
||||||
_session_maker: sessionmaker
|
_session_maker: sessionmaker
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app_configurations: Dict[str, Any],
|
app_configurations: dict[str, Any],
|
||||||
_datasource_dao: DatasourceDAO,
|
_datasource_dao: DatasourceDAO,
|
||||||
session_maker: sessionmaker,
|
session_maker: sessionmaker,
|
||||||
):
|
):
|
||||||
|
@ -48,11 +48,11 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
||||||
def create( # pylint: disable=too-many-arguments
|
def create( # pylint: disable=too-many-arguments
|
||||||
self,
|
self,
|
||||||
parent_result_type: ChartDataResultType,
|
parent_result_type: ChartDataResultType,
|
||||||
datasource: Optional[DatasourceDict] = None,
|
datasource: DatasourceDict | None = None,
|
||||||
extras: Optional[Dict[str, Any]] = None,
|
extras: dict[str, Any] | None = None,
|
||||||
row_limit: Optional[int] = None,
|
row_limit: int | None = None,
|
||||||
time_range: Optional[str] = None,
|
time_range: str | None = None,
|
||||||
time_shift: Optional[str] = None,
|
time_shift: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> QueryObject:
|
) -> QueryObject:
|
||||||
datasource_model_instance = None
|
datasource_model_instance = None
|
||||||
|
@ -84,13 +84,13 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
def _process_extras( # pylint: disable=no-self-use
|
def _process_extras( # pylint: disable=no-self-use
|
||||||
self,
|
self,
|
||||||
extras: Optional[Dict[str, Any]],
|
extras: dict[str, Any] | None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
extras = extras or {}
|
extras = extras or {}
|
||||||
return extras
|
return extras
|
||||||
|
|
||||||
def _process_row_limit(
|
def _process_row_limit(
|
||||||
self, row_limit: Optional[int], result_type: ChartDataResultType
|
self, row_limit: int | None, result_type: ChartDataResultType
|
||||||
) -> int:
|
) -> int:
|
||||||
default_row_limit = (
|
default_row_limit = (
|
||||||
self._config["SAMPLES_ROW_LIMIT"]
|
self._config["SAMPLES_ROW_LIMIT"]
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import MetaData
|
from sqlalchemy import MetaData
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
@ -25,7 +25,7 @@ from superset.tags.models import ObjectTypes, TagTypes
|
||||||
|
|
||||||
|
|
||||||
def add_types_to_charts(
|
def add_types_to_charts(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
slices = metadata.tables["slices"]
|
slices = metadata.tables["slices"]
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ def add_types_to_charts(
|
||||||
|
|
||||||
|
|
||||||
def add_types_to_dashboards(
|
def add_types_to_dashboards(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
dashboard_table = metadata.tables["dashboards"]
|
dashboard_table = metadata.tables["dashboards"]
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ def add_types_to_dashboards(
|
||||||
|
|
||||||
|
|
||||||
def add_types_to_saved_queries(
|
def add_types_to_saved_queries(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
saved_query = metadata.tables["saved_query"]
|
saved_query = metadata.tables["saved_query"]
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ def add_types_to_saved_queries(
|
||||||
|
|
||||||
|
|
||||||
def add_types_to_datasets(
|
def add_types_to_datasets(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
tables = metadata.tables["tables"]
|
tables = metadata.tables["tables"]
|
||||||
|
|
||||||
|
@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None:
|
||||||
|
|
||||||
|
|
||||||
def add_owners_to_charts(
|
def add_owners_to_charts(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
slices = metadata.tables["slices"]
|
slices = metadata.tables["slices"]
|
||||||
|
|
||||||
|
@ -273,7 +273,7 @@ def add_owners_to_charts(
|
||||||
|
|
||||||
|
|
||||||
def add_owners_to_dashboards(
|
def add_owners_to_dashboards(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
dashboard_table = metadata.tables["dashboards"]
|
dashboard_table = metadata.tables["dashboards"]
|
||||||
|
|
||||||
|
@ -309,7 +309,7 @@ def add_owners_to_dashboards(
|
||||||
|
|
||||||
|
|
||||||
def add_owners_to_saved_queries(
|
def add_owners_to_saved_queries(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
saved_query = metadata.tables["saved_query"]
|
saved_query = metadata.tables["saved_query"]
|
||||||
|
|
||||||
|
@ -345,7 +345,7 @@ def add_owners_to_saved_queries(
|
||||||
|
|
||||||
|
|
||||||
def add_owners_to_datasets(
|
def add_owners_to_datasets(
|
||||||
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
|
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
tables = metadata.tables["tables"]
|
tables = metadata.tables["tables"]
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Any, List, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||||
def left_join_df(
|
def left_join_df(
|
||||||
left_df: pd.DataFrame,
|
left_df: pd.DataFrame,
|
||||||
right_df: pd.DataFrame,
|
right_df: pd.DataFrame,
|
||||||
join_keys: List[str],
|
join_keys: list[str],
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
|
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
|
||||||
df.reset_index(inplace=True)
|
df.reset_index(inplace=True)
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from flask_caching import Cache
|
from flask_caching import Cache
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
|
@ -37,7 +37,7 @@ config = app.config
|
||||||
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
|
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_cache: Dict[CacheRegion, Cache] = {
|
_cache: dict[CacheRegion, Cache] = {
|
||||||
CacheRegion.DEFAULT: cache_manager.cache,
|
CacheRegion.DEFAULT: cache_manager.cache,
|
||||||
CacheRegion.DATA: cache_manager.data_cache,
|
CacheRegion.DATA: cache_manager.data_cache,
|
||||||
}
|
}
|
||||||
|
@ -53,17 +53,17 @@ class QueryCacheManager:
|
||||||
self,
|
self,
|
||||||
df: DataFrame = DataFrame(),
|
df: DataFrame = DataFrame(),
|
||||||
query: str = "",
|
query: str = "",
|
||||||
annotation_data: Optional[Dict[str, Any]] = None,
|
annotation_data: dict[str, Any] | None = None,
|
||||||
applied_template_filters: Optional[List[str]] = None,
|
applied_template_filters: list[str] | None = None,
|
||||||
applied_filter_columns: Optional[List[Column]] = None,
|
applied_filter_columns: list[Column] | None = None,
|
||||||
rejected_filter_columns: Optional[List[Column]] = None,
|
rejected_filter_columns: list[Column] | None = None,
|
||||||
status: Optional[str] = None,
|
status: str | None = None,
|
||||||
error_message: Optional[str] = None,
|
error_message: str | None = None,
|
||||||
is_loaded: bool = False,
|
is_loaded: bool = False,
|
||||||
stacktrace: Optional[str] = None,
|
stacktrace: str | None = None,
|
||||||
is_cached: Optional[bool] = None,
|
is_cached: bool | None = None,
|
||||||
cache_dttm: Optional[str] = None,
|
cache_dttm: str | None = None,
|
||||||
cache_value: Optional[Dict[str, Any]] = None,
|
cache_value: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.df = df
|
self.df = df
|
||||||
self.query = query
|
self.query = query
|
||||||
|
@ -85,10 +85,10 @@ class QueryCacheManager:
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
query_result: QueryResult,
|
query_result: QueryResult,
|
||||||
annotation_data: Optional[Dict[str, Any]] = None,
|
annotation_data: dict[str, Any] | None = None,
|
||||||
force_query: Optional[bool] = False,
|
force_query: bool | None = False,
|
||||||
timeout: Optional[int] = None,
|
timeout: int | None = None,
|
||||||
datasource_uid: Optional[str] = None,
|
datasource_uid: str | None = None,
|
||||||
region: CacheRegion = CacheRegion.DEFAULT,
|
region: CacheRegion = CacheRegion.DEFAULT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -136,11 +136,11 @@ class QueryCacheManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(
|
def get(
|
||||||
cls,
|
cls,
|
||||||
key: Optional[str],
|
key: str | None,
|
||||||
region: CacheRegion = CacheRegion.DEFAULT,
|
region: CacheRegion = CacheRegion.DEFAULT,
|
||||||
force_query: Optional[bool] = False,
|
force_query: bool | None = False,
|
||||||
force_cached: Optional[bool] = False,
|
force_cached: bool | None = False,
|
||||||
) -> "QueryCacheManager":
|
) -> QueryCacheManager:
|
||||||
"""
|
"""
|
||||||
Initialize QueryCacheManager by query-cache key
|
Initialize QueryCacheManager by query-cache key
|
||||||
"""
|
"""
|
||||||
|
@ -190,10 +190,10 @@ class QueryCacheManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set(
|
def set(
|
||||||
key: Optional[str],
|
key: str | None,
|
||||||
value: Dict[str, Any],
|
value: dict[str, Any],
|
||||||
timeout: Optional[int] = None,
|
timeout: int | None = None,
|
||||||
datasource_uid: Optional[str] = None,
|
datasource_uid: str | None = None,
|
||||||
region: CacheRegion = CacheRegion.DEFAULT,
|
region: CacheRegion = CacheRegion.DEFAULT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -204,7 +204,7 @@ class QueryCacheManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete(
|
def delete(
|
||||||
key: Optional[str],
|
key: str | None,
|
||||||
region: CacheRegion = CacheRegion.DEFAULT,
|
region: CacheRegion = CacheRegion.DEFAULT,
|
||||||
) -> None:
|
) -> None:
|
||||||
if key:
|
if key:
|
||||||
|
@ -212,7 +212,7 @@ class QueryCacheManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has(
|
def has(
|
||||||
key: Optional[str],
|
key: str | None,
|
||||||
region: CacheRegion = CacheRegion.DEFAULT,
|
region: CacheRegion = CacheRegion.DEFAULT,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return bool(_cache[region].get(key)) if key else False
|
return bool(_cache[region].get(key)) if key else False
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast, Dict, Optional, Tuple
|
from typing import Any, cast
|
||||||
|
|
||||||
from superset import app
|
from superset import app
|
||||||
from superset.common.query_object import QueryObject
|
from superset.common.query_object import QueryObject
|
||||||
|
@ -26,10 +26,10 @@ from superset.utils.date_parser import get_since_until
|
||||||
|
|
||||||
|
|
||||||
def get_since_until_from_time_range(
|
def get_since_until_from_time_range(
|
||||||
time_range: Optional[str] = None,
|
time_range: str | None = None,
|
||||||
time_shift: Optional[str] = None,
|
time_shift: str | None = None,
|
||||||
extras: Optional[Dict[str, Any]] = None,
|
extras: dict[str, Any] | None = None,
|
||||||
) -> Tuple[Optional[datetime], Optional[datetime]]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
return get_since_until(
|
return get_since_until(
|
||||||
relative_start=(extras or {}).get(
|
relative_start=(extras or {}).get(
|
||||||
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
|
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
|
||||||
|
@ -45,7 +45,7 @@ def get_since_until_from_time_range(
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
def get_since_until_from_query_object(
|
def get_since_until_from_query_object(
|
||||||
query_object: QueryObject,
|
query_object: QueryObject,
|
||||||
) -> Tuple[Optional[datetime], Optional[datetime]]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
"""
|
"""
|
||||||
this function will return since and until by tuple if
|
this function will return since and until by tuple if
|
||||||
1) the time_range is in the query object.
|
1) the time_range is in the query object.
|
||||||
|
|
|
@ -33,20 +33,7 @@ import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from typing import (
|
from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
TypedDict,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from cachelib.base import BaseCache
|
from cachelib.base import BaseCache
|
||||||
|
@ -114,17 +101,17 @@ PACKAGE_JSON_FILE = pkg_resources.resource_filename(
|
||||||
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
|
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
|
||||||
|
|
||||||
|
|
||||||
def _try_json_readversion(filepath: str) -> Optional[str]:
|
def _try_json_readversion(filepath: str) -> str | None:
|
||||||
try:
|
try:
|
||||||
with open(filepath, "r") as f:
|
with open(filepath) as f:
|
||||||
return json.load(f).get("version")
|
return json.load(f).get("version")
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
|
def _try_json_readsha(filepath: str, length: int) -> str | None:
|
||||||
try:
|
try:
|
||||||
with open(filepath, "r") as f:
|
with open(filepath) as f:
|
||||||
return json.load(f).get("GIT_SHA")[:length]
|
return json.load(f).get("GIT_SHA")[:length]
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
return None
|
return None
|
||||||
|
@ -275,7 +262,7 @@ ENABLE_PROXY_FIX = False
|
||||||
PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1}
|
PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1}
|
||||||
|
|
||||||
# Configuration for scheduling queries from SQL Lab.
|
# Configuration for scheduling queries from SQL Lab.
|
||||||
SCHEDULED_QUERIES: Dict[str, Any] = {}
|
SCHEDULED_QUERIES: dict[str, Any] = {}
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# GLOBALS FOR APP Builder
|
# GLOBALS FOR APP Builder
|
||||||
|
@ -294,7 +281,7 @@ LOGO_TARGET_PATH = None
|
||||||
LOGO_TOOLTIP = ""
|
LOGO_TOOLTIP = ""
|
||||||
|
|
||||||
# Specify any text that should appear to the right of the logo
|
# Specify any text that should appear to the right of the logo
|
||||||
LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = ""
|
LOGO_RIGHT_TEXT: Callable[[], str] | str = ""
|
||||||
|
|
||||||
# Enables SWAGGER UI for superset openapi spec
|
# Enables SWAGGER UI for superset openapi spec
|
||||||
# ex: http://localhost:8080/swagger/v1
|
# ex: http://localhost:8080/swagger/v1
|
||||||
|
@ -347,7 +334,7 @@ AUTH_TYPE = AUTH_DB
|
||||||
# Grant public role the same set of permissions as for a selected builtin role.
|
# Grant public role the same set of permissions as for a selected builtin role.
|
||||||
# This is useful if one wants to enable anonymous users to view
|
# This is useful if one wants to enable anonymous users to view
|
||||||
# dashboards. Explicit grant on specific datasets is still required.
|
# dashboards. Explicit grant on specific datasets is still required.
|
||||||
PUBLIC_ROLE_LIKE: Optional[str] = None
|
PUBLIC_ROLE_LIKE: str | None = None
|
||||||
|
|
||||||
# ---------------------------------------------------
|
# ---------------------------------------------------
|
||||||
# Babel config for translations
|
# Babel config for translations
|
||||||
|
@ -390,8 +377,8 @@ LANGUAGES = {}
|
||||||
class D3Format(TypedDict, total=False):
|
class D3Format(TypedDict, total=False):
|
||||||
decimal: str
|
decimal: str
|
||||||
thousands: str
|
thousands: str
|
||||||
grouping: List[int]
|
grouping: list[int]
|
||||||
currency: List[str]
|
currency: list[str]
|
||||||
|
|
||||||
|
|
||||||
D3_FORMAT: D3Format = {}
|
D3_FORMAT: D3Format = {}
|
||||||
|
@ -404,7 +391,7 @@ D3_FORMAT: D3Format = {}
|
||||||
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
|
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
|
||||||
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
|
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
|
||||||
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
|
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
|
||||||
DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
|
DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
|
||||||
# Experimental feature introducing a client (browser) cache
|
# Experimental feature introducing a client (browser) cache
|
||||||
"CLIENT_CACHE": False, # deprecated
|
"CLIENT_CACHE": False, # deprecated
|
||||||
"DISABLE_DATASET_SOURCE_EDIT": False, # deprecated
|
"DISABLE_DATASET_SOURCE_EDIT": False, # deprecated
|
||||||
|
@ -527,7 +514,7 @@ DEFAULT_FEATURE_FLAGS.update(
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is merely a default.
|
# This is merely a default.
|
||||||
FEATURE_FLAGS: Dict[str, bool] = {}
|
FEATURE_FLAGS: dict[str, bool] = {}
|
||||||
|
|
||||||
# A function that receives a dict of all feature flags
|
# A function that receives a dict of all feature flags
|
||||||
# (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
|
# (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
|
||||||
|
@ -543,7 +530,7 @@ FEATURE_FLAGS: Dict[str, bool] = {}
|
||||||
# if hasattr(g, "user") and g.user.is_active:
|
# if hasattr(g, "user") and g.user.is_active:
|
||||||
# feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
|
# feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
|
||||||
# return feature_flags_dict
|
# return feature_flags_dict
|
||||||
GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
|
GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None
|
||||||
# A function that receives a feature flag name and an optional default value.
|
# A function that receives a feature flag name and an optional default value.
|
||||||
# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
|
# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
|
||||||
# evaluation of all feature flags when just evaluating a single one.
|
# evaluation of all feature flags when just evaluating a single one.
|
||||||
|
@ -551,7 +538,7 @@ GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] =
|
||||||
# Note that the default `get_feature_flags` will evaluate each feature with this
|
# Note that the default `get_feature_flags` will evaluate each feature with this
|
||||||
# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
|
# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
|
||||||
# and IS_FEATURE_ENABLED_FUNC in conjunction.
|
# and IS_FEATURE_ENABLED_FUNC in conjunction.
|
||||||
IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
|
IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
|
||||||
# A function that expands/overrides the frontend `bootstrap_data.common` object.
|
# A function that expands/overrides the frontend `bootstrap_data.common` object.
|
||||||
# Can be used to implement custom frontend functionality,
|
# Can be used to implement custom frontend functionality,
|
||||||
# or dynamically change certain configs.
|
# or dynamically change certain configs.
|
||||||
|
@ -563,7 +550,7 @@ IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
|
||||||
# Takes as a parameter the common bootstrap payload before transformations.
|
# Takes as a parameter the common bootstrap payload before transformations.
|
||||||
# Returns a dict containing data that should be added or overridden to the payload.
|
# Returns a dict containing data that should be added or overridden to the payload.
|
||||||
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
|
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
|
||||||
[Dict[str, Any]], Dict[str, Any]
|
[dict[str, Any]], dict[str, Any]
|
||||||
] = lambda data: {} # default: empty dict
|
] = lambda data: {} # default: empty dict
|
||||||
|
|
||||||
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
|
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
|
||||||
|
@ -580,7 +567,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
|
||||||
# }]
|
# }]
|
||||||
|
|
||||||
# This is merely a default
|
# This is merely a default
|
||||||
EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
|
EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# THEME_OVERRIDES is used for adding custom theme to superset
|
# THEME_OVERRIDES is used for adding custom theme to superset
|
||||||
# example code for "My theme" custom scheme
|
# example code for "My theme" custom scheme
|
||||||
|
@ -599,7 +586,7 @@ EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
|
|
||||||
THEME_OVERRIDES: Dict[str, Any] = {}
|
THEME_OVERRIDES: dict[str, Any] = {}
|
||||||
|
|
||||||
# EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes
|
# EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes
|
||||||
# EXTRA_SEQUENTIAL_COLOR_SCHEMES = [
|
# EXTRA_SEQUENTIAL_COLOR_SCHEMES = [
|
||||||
|
@ -615,7 +602,7 @@ THEME_OVERRIDES: Dict[str, Any] = {}
|
||||||
# }]
|
# }]
|
||||||
|
|
||||||
# This is merely a default
|
# This is merely a default
|
||||||
EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
|
EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# ---------------------------------------------------
|
# ---------------------------------------------------
|
||||||
# Thumbnail config (behind feature flag)
|
# Thumbnail config (behind feature flag)
|
||||||
|
@ -626,7 +613,7 @@ EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
|
||||||
# `superset.tasks.types.ExecutorType` for a full list of executor options.
|
# `superset.tasks.types.ExecutorType` for a full list of executor options.
|
||||||
# To always use a fixed user account, use the following configuration:
|
# To always use a fixed user account, use the following configuration:
|
||||||
# THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM]
|
# THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM]
|
||||||
THUMBNAIL_SELENIUM_USER: Optional[str] = "admin"
|
THUMBNAIL_SELENIUM_USER: str | None = "admin"
|
||||||
THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
|
THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
|
||||||
|
|
||||||
# By default, thumbnail digests are calculated based on various parameters in the
|
# By default, thumbnail digests are calculated based on various parameters in the
|
||||||
|
@ -639,10 +626,10 @@ THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
|
||||||
# `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in
|
# `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in
|
||||||
# user if the executor type is equal to `ExecutorType.CURRENT_USER`)
|
# user if the executor type is equal to `ExecutorType.CURRENT_USER`)
|
||||||
# and return the final digest string:
|
# and return the final digest string:
|
||||||
THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[
|
THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | (
|
||||||
Callable[[Dashboard, ExecutorType, str], str]
|
Callable[[Dashboard, ExecutorType, str], str]
|
||||||
] = None
|
) = None
|
||||||
THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None
|
THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None
|
||||||
|
|
||||||
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
|
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
|
||||||
"CACHE_TYPE": "NullCache",
|
"CACHE_TYPE": "NullCache",
|
||||||
|
@ -714,7 +701,7 @@ STORE_CACHE_KEYS_IN_METADATA_DB = False
|
||||||
|
|
||||||
# CORS Options
|
# CORS Options
|
||||||
ENABLE_CORS = False
|
ENABLE_CORS = False
|
||||||
CORS_OPTIONS: Dict[Any, Any] = {}
|
CORS_OPTIONS: dict[Any, Any] = {}
|
||||||
|
|
||||||
# Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner.
|
# Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner.
|
||||||
# Disabling this option is not recommended for security reasons. If you wish to allow
|
# Disabling this option is not recommended for security reasons. If you wish to allow
|
||||||
|
@ -736,7 +723,7 @@ HTML_SANITIZATION = True
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
# Be careful when extending the default schema to avoid XSS attacks.
|
# Be careful when extending the default schema to avoid XSS attacks.
|
||||||
HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {}
|
HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {}
|
||||||
|
|
||||||
# Chrome allows up to 6 open connections per domain at a time. When there are more
|
# Chrome allows up to 6 open connections per domain at a time. When there are more
|
||||||
# than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for
|
# than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for
|
||||||
|
@ -768,13 +755,13 @@ EXCEL_EXPORT = {"encoding": "utf-8"}
|
||||||
# time grains in superset/db_engine_specs/base.py).
|
# time grains in superset/db_engine_specs/base.py).
|
||||||
# For example: to disable 1 second time grain:
|
# For example: to disable 1 second time grain:
|
||||||
# TIME_GRAIN_DENYLIST = ['PT1S']
|
# TIME_GRAIN_DENYLIST = ['PT1S']
|
||||||
TIME_GRAIN_DENYLIST: List[str] = []
|
TIME_GRAIN_DENYLIST: list[str] = []
|
||||||
|
|
||||||
# Additional time grains to be supported using similar definitions as in
|
# Additional time grains to be supported using similar definitions as in
|
||||||
# superset/db_engine_specs/base.py.
|
# superset/db_engine_specs/base.py.
|
||||||
# For example: To add a new 2 second time grain:
|
# For example: To add a new 2 second time grain:
|
||||||
# TIME_GRAIN_ADDONS = {'PT2S': '2 second'}
|
# TIME_GRAIN_ADDONS = {'PT2S': '2 second'}
|
||||||
TIME_GRAIN_ADDONS: Dict[str, str] = {}
|
TIME_GRAIN_ADDONS: dict[str, str] = {}
|
||||||
|
|
||||||
# Implementation of additional time grains per engine.
|
# Implementation of additional time grains per engine.
|
||||||
# The column to be truncated is denoted `{col}` in the expression.
|
# The column to be truncated is denoted `{col}` in the expression.
|
||||||
|
@ -784,7 +771,7 @@ TIME_GRAIN_ADDONS: Dict[str, str] = {}
|
||||||
# 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)'
|
# 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)'
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
|
TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {}
|
||||||
|
|
||||||
# ---------------------------------------------------
|
# ---------------------------------------------------
|
||||||
# List of viz_types not allowed in your environment
|
# List of viz_types not allowed in your environment
|
||||||
|
@ -792,7 +779,7 @@ TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
|
||||||
# VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap']
|
# VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap']
|
||||||
# ---------------------------------------------------
|
# ---------------------------------------------------
|
||||||
|
|
||||||
VIZ_TYPE_DENYLIST: List[str] = []
|
VIZ_TYPE_DENYLIST: list[str] = []
|
||||||
|
|
||||||
# --------------------------------------------------
|
# --------------------------------------------------
|
||||||
# Modules, datasources and middleware to be registered
|
# Modules, datasources and middleware to be registered
|
||||||
|
@ -802,8 +789,8 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
|
||||||
("superset.connectors.sqla.models", ["SqlaTable"]),
|
("superset.connectors.sqla.models", ["SqlaTable"]),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
|
ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {}
|
||||||
ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
|
ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = []
|
||||||
|
|
||||||
# 1) https://docs.python-guide.org/writing/logging/
|
# 1) https://docs.python-guide.org/writing/logging/
|
||||||
# 2) https://docs.python.org/2/library/logging.config.html
|
# 2) https://docs.python.org/2/library/logging.config.html
|
||||||
|
@ -925,9 +912,9 @@ CELERY_CONFIG = CeleryConfig # pylint: disable=invalid-name
|
||||||
# within the app
|
# within the app
|
||||||
# OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will
|
# OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will
|
||||||
# override anything set within the app
|
# override anything set within the app
|
||||||
DEFAULT_HTTP_HEADERS: Dict[str, Any] = {}
|
DEFAULT_HTTP_HEADERS: dict[str, Any] = {}
|
||||||
OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {}
|
OVERRIDE_HTTP_HEADERS: dict[str, Any] = {}
|
||||||
HTTP_HEADERS: Dict[str, Any] = {}
|
HTTP_HEADERS: dict[str, Any] = {}
|
||||||
|
|
||||||
# The db id here results in selecting this one as a default in SQL Lab
|
# The db id here results in selecting this one as a default in SQL Lab
|
||||||
DEFAULT_DB_ID = None
|
DEFAULT_DB_ID = None
|
||||||
|
@ -974,8 +961,8 @@ SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
|
||||||
# return out
|
# return out
|
||||||
#
|
#
|
||||||
# QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter}
|
# QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter}
|
||||||
QUERY_COST_FORMATTERS_BY_ENGINE: Dict[
|
QUERY_COST_FORMATTERS_BY_ENGINE: dict[
|
||||||
str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
|
str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
# Flag that controls if limit should be enforced on the CTA (create table as queries).
|
# Flag that controls if limit should be enforced on the CTA (create table as queries).
|
||||||
|
@ -1000,13 +987,13 @@ SQLLAB_CTAS_NO_LIMIT = False
|
||||||
# else:
|
# else:
|
||||||
# return f'tmp_{schema}'
|
# return f'tmp_{schema}'
|
||||||
# Function accepts database object, user object, schema name and sql that will be run.
|
# Function accepts database object, user object, schema name and sql that will be run.
|
||||||
SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
|
SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | (
|
||||||
Callable[[Database, models.User, str, str], str]
|
Callable[[Database, models.User, str, str], str]
|
||||||
] = None
|
) = None
|
||||||
|
|
||||||
# If enabled, it can be used to store the results of long-running queries
|
# If enabled, it can be used to store the results of long-running queries
|
||||||
# in SQL Lab by using the "Run Async" button/feature
|
# in SQL Lab by using the "Run Async" button/feature
|
||||||
RESULTS_BACKEND: Optional[BaseCache] = None
|
RESULTS_BACKEND: BaseCache | None = None
|
||||||
|
|
||||||
# Use PyArrow and MessagePack for async query results serialization,
|
# Use PyArrow and MessagePack for async query results serialization,
|
||||||
# rather than JSON. This feature requires additional testing from the
|
# rather than JSON. This feature requires additional testing from the
|
||||||
|
@ -1028,7 +1015,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
|
||||||
def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
|
def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
|
||||||
database: Database,
|
database: Database,
|
||||||
user: models.User, # pylint: disable=unused-argument
|
user: models.User, # pylint: disable=unused-argument
|
||||||
schema: Optional[str],
|
schema: str | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Note the final empty path enforces a trailing slash.
|
# Note the final empty path enforces a trailing slash.
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
|
@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
|
||||||
|
|
||||||
# The namespace within hive where the tables created from
|
# The namespace within hive where the tables created from
|
||||||
# uploading CSVs will be stored.
|
# uploading CSVs will be stored.
|
||||||
UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
|
UPLOADED_CSV_HIVE_NAMESPACE: str | None = None
|
||||||
|
|
||||||
# Function that computes the allowed schemas for the CSV uploads.
|
# Function that computes the allowed schemas for the CSV uploads.
|
||||||
# Allowed schemas will be a union of schemas_allowed_for_file_upload
|
# Allowed schemas will be a union of schemas_allowed_for_file_upload
|
||||||
# db configuration and a result of this function.
|
# db configuration and a result of this function.
|
||||||
|
|
||||||
# mypy doesn't catch that if case ensures list content being always str
|
# mypy doesn't catch that if case ensures list content being always str
|
||||||
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = (
|
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = (
|
||||||
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
|
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
|
||||||
if UPLOADED_CSV_HIVE_NAMESPACE
|
if UPLOADED_CSV_HIVE_NAMESPACE
|
||||||
else []
|
else []
|
||||||
|
@ -1062,7 +1049,7 @@ CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
|
||||||
# It's important to make sure that the objects exposed (as well as objects attached
|
# It's important to make sure that the objects exposed (as well as objects attached
|
||||||
# to those objets) are harmless. We recommend only exposing simple/pure functions that
|
# to those objets) are harmless. We recommend only exposing simple/pure functions that
|
||||||
# return native types.
|
# return native types.
|
||||||
JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
|
JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {}
|
||||||
|
|
||||||
# A dictionary of macro template processors (by engine) that gets merged into global
|
# A dictionary of macro template processors (by engine) that gets merged into global
|
||||||
# template processors. The existing template processors get updated with this
|
# template processors. The existing template processors get updated with this
|
||||||
|
@ -1070,7 +1057,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
|
||||||
# dictionary. The customized addons don't necessarily need to use Jinja templating
|
# dictionary. The customized addons don't necessarily need to use Jinja templating
|
||||||
# language. This allows you to define custom logic to process templates on a per-engine
|
# language. This allows you to define custom logic to process templates on a per-engine
|
||||||
# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}`
|
# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}`
|
||||||
CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
|
CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {}
|
||||||
|
|
||||||
# Roles that are controlled by the API / Superset and should not be changes
|
# Roles that are controlled by the API / Superset and should not be changes
|
||||||
# by humans.
|
# by humans.
|
||||||
|
@ -1125,7 +1112,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
|
||||||
|
|
||||||
# Integrate external Blueprints to the app by passing them to your
|
# Integrate external Blueprints to the app by passing them to your
|
||||||
# configuration. These blueprints will get integrated in the app
|
# configuration. These blueprints will get integrated in the app
|
||||||
BLUEPRINTS: List[Blueprint] = []
|
BLUEPRINTS: list[Blueprint] = []
|
||||||
|
|
||||||
# Provide a callable that receives a tracking_url and returns another
|
# Provide a callable that receives a tracking_url and returns another
|
||||||
# URL. This is used to translate internal Hadoop job tracker URL
|
# URL. This is used to translate internal Hadoop job tracker URL
|
||||||
|
@ -1142,7 +1129,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url
|
||||||
|
|
||||||
|
|
||||||
# customize the polling time of each engine
|
# customize the polling time of each engine
|
||||||
DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {}
|
DB_POLL_INTERVAL_SECONDS: dict[str, int] = {}
|
||||||
|
|
||||||
# Interval between consecutive polls when using Presto Engine
|
# Interval between consecutive polls when using Presto Engine
|
||||||
# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression
|
# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression
|
||||||
|
@ -1159,7 +1146,7 @@ PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds())
|
||||||
# "another_auth_method": auth_method,
|
# "another_auth_method": auth_method,
|
||||||
# },
|
# },
|
||||||
# }
|
# }
|
||||||
ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {}
|
ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
|
||||||
|
|
||||||
# The id of a template dashboard that should be copied to every new user
|
# The id of a template dashboard that should be copied to every new user
|
||||||
DASHBOARD_TEMPLATE_ID = None
|
DASHBOARD_TEMPLATE_ID = None
|
||||||
|
@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
|
||||||
# Owners, filters for created_by, etc.
|
# Owners, filters for created_by, etc.
|
||||||
# The users can also be excluded by overriding the get_exclude_users_from_lists method
|
# The users can also be excluded by overriding the get_exclude_users_from_lists method
|
||||||
# in security manager
|
# in security manager
|
||||||
EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None
|
EXCLUDE_USERS_FROM_LISTS: list[str] | None = None
|
||||||
|
|
||||||
# For database connections, this dictionary will remove engines from the available
|
# For database connections, this dictionary will remove engines from the available
|
||||||
# list/dropdown if you do not want these dbs to show as available.
|
# list/dropdown if you do not want these dbs to show as available.
|
||||||
# The available list is generated by driver installed, and some engines have multiple
|
# The available list is generated by driver installed, and some engines have multiple
|
||||||
# drivers.
|
# drivers.
|
||||||
# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}}
|
# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}}
|
||||||
DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {}
|
DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {}
|
||||||
|
|
||||||
# This auth provider is used by background (offline) tasks that need to access
|
# This auth provider is used by background (offline) tasks that need to access
|
||||||
# protected resources. Can be overridden by end users in order to support
|
# protected resources. Can be overridden by end users in order to support
|
||||||
|
@ -1261,7 +1248,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True
|
||||||
# ExecutorType.OWNER,
|
# ExecutorType.OWNER,
|
||||||
# ExecutorType.SELENIUM,
|
# ExecutorType.SELENIUM,
|
||||||
# ]
|
# ]
|
||||||
ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER]
|
ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER]
|
||||||
# if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout
|
# if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout
|
||||||
# Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG
|
# Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG
|
||||||
ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds())
|
ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds())
|
||||||
|
@ -1286,7 +1273,7 @@ EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
|
||||||
EMAIL_REPORTS_CTA = "Explore in Superset"
|
EMAIL_REPORTS_CTA = "Explore in Superset"
|
||||||
|
|
||||||
# Slack API token for the superset reports, either string or callable
|
# Slack API token for the superset reports, either string or callable
|
||||||
SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None
|
SLACK_API_TOKEN: Callable[[], str] | str | None = None
|
||||||
SLACK_PROXY = None
|
SLACK_PROXY = None
|
||||||
|
|
||||||
# The webdriver to use for generating reports. Use one of the following
|
# The webdriver to use for generating reports. Use one of the following
|
||||||
|
@ -1310,7 +1297,7 @@ WEBDRIVER_WINDOW = {
|
||||||
WEBDRIVER_AUTH_FUNC = None
|
WEBDRIVER_AUTH_FUNC = None
|
||||||
|
|
||||||
# Any config options to be passed as-is to the webdriver
|
# Any config options to be passed as-is to the webdriver
|
||||||
WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"}
|
WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"}
|
||||||
|
|
||||||
# Additional args to be passed as arguments to the config object
|
# Additional args to be passed as arguments to the config object
|
||||||
# Note: If using Chrome, you'll want to add the "--marionette" arg.
|
# Note: If using Chrome, you'll want to add the "--marionette" arg.
|
||||||
|
@ -1353,7 +1340,7 @@ SQL_VALIDATORS_BY_ENGINE = {
|
||||||
# displayed prominently in the "Add Database" dialog. You should
|
# displayed prominently in the "Add Database" dialog. You should
|
||||||
# use the "engine_name" attribute of the corresponding DB engine spec
|
# use the "engine_name" attribute of the corresponding DB engine spec
|
||||||
# in `superset/db_engine_specs/`.
|
# in `superset/db_engine_specs/`.
|
||||||
PREFERRED_DATABASES: List[str] = [
|
PREFERRED_DATABASES: list[str] = [
|
||||||
"PostgreSQL",
|
"PostgreSQL",
|
||||||
"Presto",
|
"Presto",
|
||||||
"MySQL",
|
"MySQL",
|
||||||
|
@ -1386,7 +1373,7 @@ TALISMAN_CONFIG = {
|
||||||
#
|
#
|
||||||
SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS?
|
SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS?
|
||||||
SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls?
|
SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls?
|
||||||
SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax"
|
SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax"
|
||||||
# Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection
|
# Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection
|
||||||
SESSION_PROTECTION = "strong"
|
SESSION_PROTECTION = "strong"
|
||||||
|
|
||||||
|
@ -1418,7 +1405,7 @@ DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]
|
||||||
# Path used to store SSL certificates that are generated when using custom certs.
|
# Path used to store SSL certificates that are generated when using custom certs.
|
||||||
# Defaults to temporary directory.
|
# Defaults to temporary directory.
|
||||||
# Example: SSL_CERT_PATH = "/certs"
|
# Example: SSL_CERT_PATH = "/certs"
|
||||||
SSL_CERT_PATH: Optional[str] = None
|
SSL_CERT_PATH: str | None = None
|
||||||
|
|
||||||
# SQLA table mutator, every time we fetch the metadata for a certain table
|
# SQLA table mutator, every time we fetch the metadata for a certain table
|
||||||
# (superset.connectors.sqla.models.SqlaTable), we call this hook
|
# (superset.connectors.sqla.models.SqlaTable), we call this hook
|
||||||
|
@ -1443,9 +1430,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
|
||||||
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
|
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
|
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
|
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[
|
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
|
||||||
Literal["None", "Lax", "Strict"]
|
Literal["None", "Lax", "Strict"]
|
||||||
] = None
|
) = None
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
|
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
|
||||||
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
|
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
|
||||||
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
|
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
|
||||||
|
@ -1461,7 +1448,7 @@ GUEST_TOKEN_JWT_ALGO = "HS256"
|
||||||
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
|
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
|
||||||
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
|
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
|
||||||
# Guest token audience for the embedded superset, either string or callable
|
# Guest token audience for the embedded superset, either string or callable
|
||||||
GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
|
GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None
|
||||||
|
|
||||||
# A SQL dataset health check. Note if enabled it is strongly advised that the callable
|
# A SQL dataset health check. Note if enabled it is strongly advised that the callable
|
||||||
# be memoized to aid with performance, i.e.,
|
# be memoized to aid with performance, i.e.,
|
||||||
|
@ -1492,7 +1479,7 @@ GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
|
||||||
# cache_manager.cache.delete_memoized(func)
|
# cache_manager.cache.delete_memoized(func)
|
||||||
# cache_manager.cache.set(name, code, timeout=0)
|
# cache_manager.cache.set(name, code, timeout=0)
|
||||||
#
|
#
|
||||||
DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None
|
DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None
|
||||||
|
|
||||||
# Do not show user info or profile in the menu
|
# Do not show user info or profile in the menu
|
||||||
MENU_HIDE_USER_INFO = False
|
MENU_HIDE_USER_INFO = False
|
||||||
|
@ -1502,7 +1489,7 @@ MENU_HIDE_USER_INFO = False
|
||||||
ENABLE_BROAD_ACTIVITY_ACCESS = True
|
ENABLE_BROAD_ACTIVITY_ACCESS = True
|
||||||
|
|
||||||
# the advanced data type key should correspond to that set in the column metadata
|
# the advanced data type key should correspond to that set in the column metadata
|
||||||
ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
|
ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = {
|
||||||
"internet_address": internet_address,
|
"internet_address": internet_address,
|
||||||
"port": internet_port,
|
"port": internet_port,
|
||||||
}
|
}
|
||||||
|
@ -1514,9 +1501,9 @@ ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
|
||||||
# "Xyz",
|
# "Xyz",
|
||||||
# [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}],
|
# [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}],
|
||||||
# )
|
# )
|
||||||
WELCOME_PAGE_LAST_TAB: Union[
|
WELCOME_PAGE_LAST_TAB: (
|
||||||
Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]]
|
Literal["examples", "all"] | tuple[str, list[dict[str, Any]]]
|
||||||
] = "all"
|
) = "all"
|
||||||
|
|
||||||
# Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag.
|
# Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag.
|
||||||
# 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base)
|
# 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base)
|
||||||
|
|
|
@ -16,21 +16,12 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import Hashable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import Any, TYPE_CHECKING
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Hashable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
@ -89,23 +80,23 @@ class BaseDatasource(
|
||||||
# ---------------------------------------------------------------
|
# ---------------------------------------------------------------
|
||||||
# class attributes to define when deriving BaseDatasource
|
# class attributes to define when deriving BaseDatasource
|
||||||
# ---------------------------------------------------------------
|
# ---------------------------------------------------------------
|
||||||
__tablename__: Optional[str] = None # {connector_name}_datasource
|
__tablename__: str | None = None # {connector_name}_datasource
|
||||||
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
|
baselink: str | None = None # url portion pointing to ModelView endpoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def column_class(self) -> Type["BaseColumn"]:
|
def column_class(self) -> type[BaseColumn]:
|
||||||
# link to derivative of BaseColumn
|
# link to derivative of BaseColumn
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric_class(self) -> Type["BaseMetric"]:
|
def metric_class(self) -> type[BaseMetric]:
|
||||||
# link to derivative of BaseMetric
|
# link to derivative of BaseMetric
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
owner_class: Optional[User] = None
|
owner_class: User | None = None
|
||||||
|
|
||||||
# Used to do code highlighting when displaying the query in the UI
|
# Used to do code highlighting when displaying the query in the UI
|
||||||
query_language: Optional[str] = None
|
query_language: str | None = None
|
||||||
|
|
||||||
# Only some datasources support Row Level Security
|
# Only some datasources support Row Level Security
|
||||||
is_rls_supported: bool = False
|
is_rls_supported: bool = False
|
||||||
|
@ -131,9 +122,9 @@ class BaseDatasource(
|
||||||
is_managed_externally = Column(Boolean, nullable=False, default=False)
|
is_managed_externally = Column(Boolean, nullable=False, default=False)
|
||||||
external_url = Column(Text, nullable=True)
|
external_url = Column(Text, nullable=True)
|
||||||
|
|
||||||
sql: Optional[str] = None
|
sql: str | None = None
|
||||||
owners: List[User]
|
owners: list[User]
|
||||||
update_from_object_fields: List[str]
|
update_from_object_fields: list[str]
|
||||||
|
|
||||||
extra_import_fields = ["is_managed_externally", "external_url"]
|
extra_import_fields = ["is_managed_externally", "external_url"]
|
||||||
|
|
||||||
|
@ -142,7 +133,7 @@ class BaseDatasource(
|
||||||
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
|
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def owners_data(self) -> List[Dict[str, Any]]:
|
def owners_data(self) -> list[dict[str, Any]]:
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"first_name": o.first_name,
|
"first_name": o.first_name,
|
||||||
|
@ -167,8 +158,8 @@ class BaseDatasource(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
columns: List["BaseColumn"] = []
|
columns: list[BaseColumn] = []
|
||||||
metrics: List["BaseMetric"] = []
|
metrics: list[BaseMetric] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
|
@ -180,11 +171,11 @@ class BaseDatasource(
|
||||||
return f"{self.id}__{self.type}"
|
return f"{self.id}__{self.type}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def column_names(self) -> List[str]:
|
def column_names(self) -> list[str]:
|
||||||
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
|
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def columns_types(self) -> Dict[str, str]:
|
def columns_types(self) -> dict[str, str]:
|
||||||
return {c.column_name: c.type for c in self.columns}
|
return {c.column_name: c.type for c in self.columns}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -196,26 +187,26 @@ class BaseDatasource(
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connection(self) -> Optional[str]:
|
def connection(self) -> str | None:
|
||||||
"""String representing the context of the Datasource"""
|
"""String representing the context of the Datasource"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schema(self) -> Optional[str]:
|
def schema(self) -> str | None:
|
||||||
"""String representing the schema of the Datasource (if it applies)"""
|
"""String representing the schema of the Datasource (if it applies)"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filterable_column_names(self) -> List[str]:
|
def filterable_column_names(self) -> list[str]:
|
||||||
return sorted([c.column_name for c in self.columns if c.filterable])
|
return sorted([c.column_name for c in self.columns if c.filterable])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dttm_cols(self) -> List[str]:
|
def dttm_cols(self) -> list[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
return "/{}/edit/{}".format(self.baselink, self.id)
|
return f"/{self.baselink}/edit/{self.id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def explore_url(self) -> str:
|
def explore_url(self) -> str:
|
||||||
|
@ -224,10 +215,10 @@ class BaseDatasource(
|
||||||
return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"
|
return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def column_formats(self) -> Dict[str, Optional[str]]:
|
def column_formats(self) -> dict[str, str | None]:
|
||||||
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
|
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
|
||||||
|
|
||||||
def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
|
def add_missing_metrics(self, metrics: list[BaseMetric]) -> None:
|
||||||
existing_metrics = {m.metric_name for m in self.metrics}
|
existing_metrics = {m.metric_name for m in self.metrics}
|
||||||
for metric in metrics:
|
for metric in metrics:
|
||||||
if metric.metric_name not in existing_metrics:
|
if metric.metric_name not in existing_metrics:
|
||||||
|
@ -235,7 +226,7 @@ class BaseDatasource(
|
||||||
self.metrics.append(metric)
|
self.metrics.append(metric)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def short_data(self) -> Dict[str, Any]:
|
def short_data(self) -> dict[str, Any]:
|
||||||
"""Data representation of the datasource sent to the frontend"""
|
"""Data representation of the datasource sent to the frontend"""
|
||||||
return {
|
return {
|
||||||
"edit_url": self.url,
|
"edit_url": self.url,
|
||||||
|
@ -249,11 +240,11 @@ class BaseDatasource(
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def select_star(self) -> Optional[str]:
|
def select_star(self) -> str | None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def order_by_choices(self) -> List[Tuple[str, str]]:
|
def order_by_choices(self) -> list[tuple[str, str]]:
|
||||||
choices = []
|
choices = []
|
||||||
# self.column_names return sorted column_names
|
# self.column_names return sorted column_names
|
||||||
for column_name in self.column_names:
|
for column_name in self.column_names:
|
||||||
|
@ -267,7 +258,7 @@ class BaseDatasource(
|
||||||
return choices
|
return choices
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def verbose_map(self) -> Dict[str, str]:
|
def verbose_map(self) -> dict[str, str]:
|
||||||
verb_map = {"__timestamp": "Time"}
|
verb_map = {"__timestamp": "Time"}
|
||||||
verb_map.update(
|
verb_map.update(
|
||||||
{o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
|
{o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
|
||||||
|
@ -278,7 +269,7 @@ class BaseDatasource(
|
||||||
return verb_map
|
return verb_map
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
"""Data representation of the datasource sent to the frontend"""
|
"""Data representation of the datasource sent to the frontend"""
|
||||||
return {
|
return {
|
||||||
# simple fields
|
# simple fields
|
||||||
|
@ -313,8 +304,8 @@ class BaseDatasource(
|
||||||
}
|
}
|
||||||
|
|
||||||
def data_for_slices( # pylint: disable=too-many-locals
|
def data_for_slices( # pylint: disable=too-many-locals
|
||||||
self, slices: List[Slice]
|
self, slices: list[Slice]
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
The representation of the datasource containing only the required data
|
The representation of the datasource containing only the required data
|
||||||
to render the provided slices.
|
to render the provided slices.
|
||||||
|
@ -381,8 +372,8 @@ class BaseDatasource(
|
||||||
if metric["metric_name"] in metric_names
|
if metric["metric_name"] in metric_names
|
||||||
]
|
]
|
||||||
|
|
||||||
filtered_columns: List[Column] = []
|
filtered_columns: list[Column] = []
|
||||||
column_types: Set[GenericDataType] = set()
|
column_types: set[GenericDataType] = set()
|
||||||
for column in data["columns"]:
|
for column in data["columns"]:
|
||||||
generic_type = column.get("type_generic")
|
generic_type = column.get("type_generic")
|
||||||
if generic_type is not None:
|
if generic_type is not None:
|
||||||
|
@ -413,18 +404,18 @@ class BaseDatasource(
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def filter_values_handler( # pylint: disable=too-many-arguments
|
def filter_values_handler( # pylint: disable=too-many-arguments
|
||||||
values: Optional[FilterValues],
|
values: FilterValues | None,
|
||||||
operator: str,
|
operator: str,
|
||||||
target_generic_type: GenericDataType,
|
target_generic_type: GenericDataType,
|
||||||
target_native_type: Optional[str] = None,
|
target_native_type: str | None = None,
|
||||||
is_list_target: bool = False,
|
is_list_target: bool = False,
|
||||||
db_engine_spec: Optional[Type[BaseEngineSpec]] = None,
|
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
|
||||||
db_extra: Optional[Dict[str, Any]] = None,
|
db_extra: dict[str, Any] | None = None,
|
||||||
) -> Optional[FilterValues]:
|
) -> FilterValues | None:
|
||||||
if values is None:
|
if values is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
|
def handle_single_value(value: FilterValue | None) -> FilterValue | None:
|
||||||
if operator == utils.FilterOperator.TEMPORAL_RANGE:
|
if operator == utils.FilterOperator.TEMPORAL_RANGE:
|
||||||
return value
|
return value
|
||||||
if (
|
if (
|
||||||
|
@ -464,7 +455,7 @@ class BaseDatasource(
|
||||||
values = values[0] if values else None
|
values = values[0] if values else None
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def external_metadata(self) -> List[Dict[str, str]]:
|
def external_metadata(self) -> list[dict[str, str]]:
|
||||||
"""Returns column information from the external system"""
|
"""Returns column information from the external system"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -483,7 +474,7 @@ class BaseDatasource(
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
|
||||||
"""Given a column, returns an iterable of distinct values
|
"""Given a column, returns an iterable of distinct values
|
||||||
|
|
||||||
This is used to populate the dropdown showing a list of
|
This is used to populate the dropdown showing a list of
|
||||||
|
@ -494,7 +485,7 @@ class BaseDatasource(
|
||||||
def default_query(qry: Query) -> Query:
|
def default_query(qry: Query) -> Query:
|
||||||
return qry
|
return qry
|
||||||
|
|
||||||
def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
|
def get_column(self, column_name: str | None) -> BaseColumn | None:
|
||||||
if not column_name:
|
if not column_name:
|
||||||
return None
|
return None
|
||||||
for col in self.columns:
|
for col in self.columns:
|
||||||
|
@ -504,11 +495,11 @@ class BaseDatasource(
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fk_many_from_list(
|
def get_fk_many_from_list(
|
||||||
object_list: List[Any],
|
object_list: list[Any],
|
||||||
fkmany: List[Column],
|
fkmany: list[Column],
|
||||||
fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
|
fkmany_class: builtins.type[BaseColumn | BaseMetric],
|
||||||
key_attr: str,
|
key_attr: str,
|
||||||
) -> List[Column]:
|
) -> list[Column]:
|
||||||
"""Update ORM one-to-many list from object list
|
"""Update ORM one-to-many list from object list
|
||||||
|
|
||||||
Used for syncing metrics and columns using the same code"""
|
Used for syncing metrics and columns using the same code"""
|
||||||
|
@ -541,7 +532,7 @@ class BaseDatasource(
|
||||||
fkmany += new_fks
|
fkmany += new_fks
|
||||||
return fkmany
|
return fkmany
|
||||||
|
|
||||||
def update_from_object(self, obj: Dict[str, Any]) -> None:
|
def update_from_object(self, obj: dict[str, Any]) -> None:
|
||||||
"""Update datasource from a data structure
|
"""Update datasource from a data structure
|
||||||
|
|
||||||
The UI's table editor crafts a complex data structure that
|
The UI's table editor crafts a complex data structure that
|
||||||
|
@ -578,7 +569,7 @@ class BaseDatasource(
|
||||||
|
|
||||||
def get_extra_cache_keys( # pylint: disable=no-self-use
|
def get_extra_cache_keys( # pylint: disable=no-self-use
|
||||||
self, query_obj: QueryObjectDict # pylint: disable=unused-argument
|
self, query_obj: QueryObjectDict # pylint: disable=unused-argument
|
||||||
) -> List[Hashable]:
|
) -> list[Hashable]:
|
||||||
"""If a datasource needs to provide additional keys for calculation of
|
"""If a datasource needs to provide additional keys for calculation of
|
||||||
cache keys, those can be provided via this method
|
cache keys, those can be provided via this method
|
||||||
|
|
||||||
|
@ -607,14 +598,14 @@ class BaseDatasource(
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_datasource_by_name(
|
def get_datasource_by_name(
|
||||||
cls, session: Session, datasource_name: str, schema: str, database_name: str
|
cls, session: Session, datasource_name: str, schema: str, database_name: str
|
||||||
) -> Optional["BaseDatasource"]:
|
) -> BaseDatasource | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class BaseColumn(AuditMixinNullable, ImportExportMixin):
|
class BaseColumn(AuditMixinNullable, ImportExportMixin):
|
||||||
"""Interface for column"""
|
"""Interface for column"""
|
||||||
|
|
||||||
__tablename__: Optional[str] = None # {connector_name}_column
|
__tablename__: str | None = None # {connector_name}_column
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
column_name = Column(String(255), nullable=False)
|
column_name = Column(String(255), nullable=False)
|
||||||
|
@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
|
||||||
is_dttm = None
|
is_dttm = None
|
||||||
|
|
||||||
# [optional] Set this to support import/export functionality
|
# [optional] Set this to support import/export functionality
|
||||||
export_fields: List[Any] = []
|
export_fields: list[Any] = []
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return str(self.column_name)
|
return str(self.column_name)
|
||||||
|
@ -666,7 +657,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
|
||||||
return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types))
|
return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type_generic(self) -> Optional[utils.GenericDataType]:
|
def type_generic(self) -> utils.GenericDataType | None:
|
||||||
if self.is_string:
|
if self.is_string:
|
||||||
return utils.GenericDataType.STRING
|
return utils.GenericDataType.STRING
|
||||||
if self.is_boolean:
|
if self.is_boolean:
|
||||||
|
@ -686,7 +677,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
attrs = (
|
attrs = (
|
||||||
"id",
|
"id",
|
||||||
"column_name",
|
"column_name",
|
||||||
|
@ -705,7 +696,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
|
||||||
class BaseMetric(AuditMixinNullable, ImportExportMixin):
|
class BaseMetric(AuditMixinNullable, ImportExportMixin):
|
||||||
"""Interface for Metrics"""
|
"""Interface for Metrics"""
|
||||||
|
|
||||||
__tablename__: Optional[str] = None # {connector_name}_metric
|
__tablename__: str | None = None # {connector_name}_metric
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
metric_name = Column(String(255), nullable=False)
|
metric_name = Column(String(255), nullable=False)
|
||||||
|
@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def perm(self) -> Optional[str]:
|
def perm(self) -> str | None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -738,7 +729,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
attrs = (
|
attrs = (
|
||||||
"id",
|
"id",
|
||||||
"metric_name",
|
"metric_name",
|
||||||
|
|
|
@ -22,21 +22,10 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Hashable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import (
|
from typing import Any, Callable, cast
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
cast,
|
|
||||||
Dict,
|
|
||||||
Hashable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import dateutil.parser
|
import dateutil.parser
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -136,9 +125,9 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MetadataResult:
|
class MetadataResult:
|
||||||
added: List[str] = field(default_factory=list)
|
added: list[str] = field(default_factory=list)
|
||||||
removed: List[str] = field(default_factory=list)
|
removed: list[str] = field(default_factory=list)
|
||||||
modified: List[str] = field(default_factory=list)
|
modified: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class AnnotationDatasource(BaseDatasource):
|
class AnnotationDatasource(BaseDatasource):
|
||||||
|
@ -190,7 +179,7 @@ class AnnotationDatasource(BaseDatasource):
|
||||||
def get_query_str(self, query_obj: QueryObjectDict) -> str:
|
def get_query_str(self, query_obj: QueryObjectDict) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@ -201,7 +190,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
||||||
__tablename__ = "table_columns"
|
__tablename__ = "table_columns"
|
||||||
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
|
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
|
||||||
table_id = Column(Integer, ForeignKey("tables.id"))
|
table_id = Column(Integer, ForeignKey("tables.id"))
|
||||||
table: Mapped["SqlaTable"] = relationship(
|
table: Mapped[SqlaTable] = relationship(
|
||||||
"SqlaTable",
|
"SqlaTable",
|
||||||
back_populates="columns",
|
back_populates="columns",
|
||||||
)
|
)
|
||||||
|
@ -263,15 +252,15 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
||||||
return self.type_generic == GenericDataType.TEMPORAL
|
return self.type_generic == GenericDataType.TEMPORAL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_engine_spec(self) -> Type[BaseEngineSpec]:
|
def db_engine_spec(self) -> type[BaseEngineSpec]:
|
||||||
return self.table.db_engine_spec
|
return self.table.db_engine_spec
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_extra(self) -> Dict[str, Any]:
|
def db_extra(self) -> dict[str, Any]:
|
||||||
return self.table.database.get_extra()
|
return self.table.database.get_extra()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type_generic(self) -> Optional[utils.GenericDataType]:
|
def type_generic(self) -> utils.GenericDataType | None:
|
||||||
if self.is_dttm:
|
if self.is_dttm:
|
||||||
return GenericDataType.TEMPORAL
|
return GenericDataType.TEMPORAL
|
||||||
|
|
||||||
|
@ -310,8 +299,8 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
||||||
|
|
||||||
def get_sqla_col(
|
def get_sqla_col(
|
||||||
self,
|
self,
|
||||||
label: Optional[str] = None,
|
label: str | None = None,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> Column:
|
) -> Column:
|
||||||
label = label or self.column_name
|
label = label or self.column_name
|
||||||
db_engine_spec = self.db_engine_spec
|
db_engine_spec = self.db_engine_spec
|
||||||
|
@ -332,10 +321,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
||||||
|
|
||||||
def get_timestamp_expression(
|
def get_timestamp_expression(
|
||||||
self,
|
self,
|
||||||
time_grain: Optional[str],
|
time_grain: str | None,
|
||||||
label: Optional[str] = None,
|
label: str | None = None,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> Union[TimestampExpression, Label]:
|
) -> TimestampExpression | Label:
|
||||||
"""
|
"""
|
||||||
Return a SQLAlchemy Core element representation of self to be used in a query.
|
Return a SQLAlchemy Core element representation of self to be used in a query.
|
||||||
|
|
||||||
|
@ -365,7 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
|
||||||
return self.table.make_sqla_column_compatible(time_expr, label)
|
return self.table.make_sqla_column_compatible(time_expr, label)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
attrs = (
|
attrs = (
|
||||||
"id",
|
"id",
|
||||||
"column_name",
|
"column_name",
|
||||||
|
@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
|
||||||
__tablename__ = "sql_metrics"
|
__tablename__ = "sql_metrics"
|
||||||
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
|
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
|
||||||
table_id = Column(Integer, ForeignKey("tables.id"))
|
table_id = Column(Integer, ForeignKey("tables.id"))
|
||||||
table: Mapped["SqlaTable"] = relationship(
|
table: Mapped[SqlaTable] = relationship(
|
||||||
"SqlaTable",
|
"SqlaTable",
|
||||||
back_populates="metrics",
|
back_populates="metrics",
|
||||||
)
|
)
|
||||||
|
@ -425,8 +414,8 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
|
||||||
|
|
||||||
def get_sqla_col(
|
def get_sqla_col(
|
||||||
self,
|
self,
|
||||||
label: Optional[str] = None,
|
label: str | None = None,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> Column:
|
) -> Column:
|
||||||
label = label or self.metric_name
|
label = label or self.metric_name
|
||||||
expression = self.expression
|
expression = self.expression
|
||||||
|
@ -437,7 +426,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
|
||||||
return self.table.make_sqla_column_compatible(sqla_col, label)
|
return self.table.make_sqla_column_compatible(sqla_col, label)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def perm(self) -> Optional[str]:
|
def perm(self) -> str | None:
|
||||||
return (
|
return (
|
||||||
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
|
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
|
||||||
obj=self, parent_name=self.table.full_name
|
obj=self, parent_name=self.table.full_name
|
||||||
|
@ -446,11 +435,11 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_perm(self) -> Optional[str]:
|
def get_perm(self) -> str | None:
|
||||||
return self.perm
|
return self.perm
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
attrs = (
|
attrs = (
|
||||||
"is_certified",
|
"is_certified",
|
||||||
"certified_by",
|
"certified_by",
|
||||||
|
@ -473,11 +462,11 @@ sqlatable_user = Table(
|
||||||
|
|
||||||
|
|
||||||
def _process_sql_expression(
|
def _process_sql_expression(
|
||||||
expression: Optional[str],
|
expression: str | None,
|
||||||
database_id: int,
|
database_id: int,
|
||||||
schema: str,
|
schema: str,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
if template_processor and expression:
|
if template_processor and expression:
|
||||||
expression = template_processor.process_template(expression)
|
expression = template_processor.process_template(expression)
|
||||||
if expression:
|
if expression:
|
||||||
|
@ -501,12 +490,12 @@ class SqlaTable(
|
||||||
type = "table"
|
type = "table"
|
||||||
query_language = "sql"
|
query_language = "sql"
|
||||||
is_rls_supported = True
|
is_rls_supported = True
|
||||||
columns: Mapped[List[TableColumn]] = relationship(
|
columns: Mapped[list[TableColumn]] = relationship(
|
||||||
TableColumn,
|
TableColumn,
|
||||||
back_populates="table",
|
back_populates="table",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
metrics: Mapped[List[SqlMetric]] = relationship(
|
metrics: Mapped[list[SqlMetric]] = relationship(
|
||||||
SqlMetric,
|
SqlMetric,
|
||||||
back_populates="table",
|
back_populates="table",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
|
@ -577,11 +566,11 @@ class SqlaTable(
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_extra(self) -> Dict[str, Any]:
|
def db_extra(self) -> dict[str, Any]:
|
||||||
return self.database.get_extra()
|
return self.database.get_extra()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _apply_cte(sql: str, cte: Optional[str]) -> str:
|
def _apply_cte(sql: str, cte: str | None) -> str:
|
||||||
"""
|
"""
|
||||||
Append a CTE before the SELECT statement if defined
|
Append a CTE before the SELECT statement if defined
|
||||||
|
|
||||||
|
@ -594,7 +583,7 @@ class SqlaTable(
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_engine_spec(self) -> Type[BaseEngineSpec]:
|
def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
|
||||||
return self.database.db_engine_spec
|
return self.database.db_engine_spec
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -637,9 +626,9 @@ class SqlaTable(
|
||||||
cls,
|
cls,
|
||||||
session: Session,
|
session: Session,
|
||||||
datasource_name: str,
|
datasource_name: str,
|
||||||
schema: Optional[str],
|
schema: str | None,
|
||||||
database_name: str,
|
database_name: str,
|
||||||
) -> Optional[SqlaTable]:
|
) -> SqlaTable | None:
|
||||||
schema = schema or None
|
schema = schema or None
|
||||||
query = (
|
query = (
|
||||||
session.query(cls)
|
session.query(cls)
|
||||||
|
@ -660,7 +649,7 @@ class SqlaTable(
|
||||||
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
|
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
|
||||||
return Markup(anchor)
|
return Markup(anchor)
|
||||||
|
|
||||||
def get_schema_perm(self) -> Optional[str]:
|
def get_schema_perm(self) -> str | None:
|
||||||
"""Returns schema permission if present, database one otherwise."""
|
"""Returns schema permission if present, database one otherwise."""
|
||||||
return security_manager.get_schema_perm(self.database, self.schema)
|
return security_manager.get_schema_perm(self.database, self.schema)
|
||||||
|
|
||||||
|
@ -685,18 +674,18 @@ class SqlaTable(
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dttm_cols(self) -> List[str]:
|
def dttm_cols(self) -> list[str]:
|
||||||
l = [c.column_name for c in self.columns if c.is_dttm]
|
l = [c.column_name for c in self.columns if c.is_dttm]
|
||||||
if self.main_dttm_col and self.main_dttm_col not in l:
|
if self.main_dttm_col and self.main_dttm_col not in l:
|
||||||
l.append(self.main_dttm_col)
|
l.append(self.main_dttm_col)
|
||||||
return l
|
return l
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_cols(self) -> List[str]:
|
def num_cols(self) -> list[str]:
|
||||||
return [c.column_name for c in self.columns if c.is_numeric]
|
return [c.column_name for c in self.columns if c.is_numeric]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def any_dttm_col(self) -> Optional[str]:
|
def any_dttm_col(self) -> str | None:
|
||||||
cols = self.dttm_cols
|
cols = self.dttm_cols
|
||||||
return cols[0] if cols else None
|
return cols[0] if cols else None
|
||||||
|
|
||||||
|
@ -713,7 +702,7 @@ class SqlaTable(
|
||||||
def sql_url(self) -> str:
|
def sql_url(self) -> str:
|
||||||
return self.database.sql_url + "?table_name=" + str(self.table_name)
|
return self.database.sql_url + "?table_name=" + str(self.table_name)
|
||||||
|
|
||||||
def external_metadata(self) -> List[Dict[str, str]]:
|
def external_metadata(self) -> list[dict[str, str]]:
|
||||||
# todo(yongjie): create a physical table column type in a separate PR
|
# todo(yongjie): create a physical table column type in a separate PR
|
||||||
if self.sql:
|
if self.sql:
|
||||||
return get_virtual_table_metadata(dataset=self) # type: ignore
|
return get_virtual_table_metadata(dataset=self) # type: ignore
|
||||||
|
@ -724,14 +713,14 @@ class SqlaTable(
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def time_column_grains(self) -> Dict[str, Any]:
|
def time_column_grains(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"time_columns": self.dttm_cols,
|
"time_columns": self.dttm_cols,
|
||||||
"time_grains": [grain.name for grain in self.database.grains()],
|
"time_grains": [grain.name for grain in self.database.grains()],
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def select_star(self) -> Optional[str]:
|
def select_star(self) -> str | None:
|
||||||
# show_cols and latest_partition set to false to avoid
|
# show_cols and latest_partition set to false to avoid
|
||||||
# the expensive cost of inspecting the DB
|
# the expensive cost of inspecting the DB
|
||||||
return self.database.select_star(
|
return self.database.select_star(
|
||||||
|
@ -739,20 +728,20 @@ class SqlaTable(
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def health_check_message(self) -> Optional[str]:
|
def health_check_message(self) -> str | None:
|
||||||
check = config["DATASET_HEALTH_CHECK"]
|
check = config["DATASET_HEALTH_CHECK"]
|
||||||
return check(self) if check else None
|
return check(self) if check else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def granularity_sqla(self) -> List[Tuple[Any, Any]]:
|
def granularity_sqla(self) -> list[tuple[Any, Any]]:
|
||||||
return utils.choicify(self.dttm_cols)
|
return utils.choicify(self.dttm_cols)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def time_grain_sqla(self) -> List[Tuple[Any, Any]]:
|
def time_grain_sqla(self) -> list[tuple[Any, Any]]:
|
||||||
return [(g.duration, g.name) for g in self.database.grains() or []]
|
return [(g.duration, g.name) for g in self.database.grains() or []]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
data_ = super().data
|
data_ = super().data
|
||||||
if self.type == "table":
|
if self.type == "table":
|
||||||
data_["granularity_sqla"] = self.granularity_sqla
|
data_["granularity_sqla"] = self.granularity_sqla
|
||||||
|
@ -767,7 +756,7 @@ class SqlaTable(
|
||||||
return data_
|
return data_
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_dict(self) -> Dict[str, Any]:
|
def extra_dict(self) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
return json.loads(self.extra)
|
return json.loads(self.extra)
|
||||||
except (TypeError, json.JSONDecodeError):
|
except (TypeError, json.JSONDecodeError):
|
||||||
|
@ -775,7 +764,7 @@ class SqlaTable(
|
||||||
|
|
||||||
def get_fetch_values_predicate(
|
def get_fetch_values_predicate(
|
||||||
self,
|
self,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> TextClause:
|
) -> TextClause:
|
||||||
fetch_values_predicate = self.fetch_values_predicate
|
fetch_values_predicate = self.fetch_values_predicate
|
||||||
if template_processor:
|
if template_processor:
|
||||||
|
@ -792,7 +781,7 @@ class SqlaTable(
|
||||||
)
|
)
|
||||||
) from ex
|
) from ex
|
||||||
|
|
||||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
|
||||||
"""Runs query against sqla to retrieve some
|
"""Runs query against sqla to retrieve some
|
||||||
sample values for the given column.
|
sample values for the given column.
|
||||||
"""
|
"""
|
||||||
|
@ -869,8 +858,8 @@ class SqlaTable(
|
||||||
return tbl
|
return tbl
|
||||||
|
|
||||||
def get_from_clause(
|
def get_from_clause(
|
||||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
self, template_processor: BaseTemplateProcessor | None = None
|
||||||
) -> Tuple[Union[TableClause, Alias], Optional[str]]:
|
) -> tuple[TableClause | Alias, str | None]:
|
||||||
"""
|
"""
|
||||||
Return where to select the columns and metrics from. Either a physical table
|
Return where to select the columns and metrics from. Either a physical table
|
||||||
or a virtual table with it's own subquery. If the FROM is referencing a
|
or a virtual table with it's own subquery. If the FROM is referencing a
|
||||||
|
@ -899,7 +888,7 @@ class SqlaTable(
|
||||||
return from_clause, cte
|
return from_clause, cte
|
||||||
|
|
||||||
def get_rendered_sql(
|
def get_rendered_sql(
|
||||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
self, template_processor: BaseTemplateProcessor | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Render sql with template engine (Jinja).
|
Render sql with template engine (Jinja).
|
||||||
|
@ -928,8 +917,8 @@ class SqlaTable(
|
||||||
def adhoc_metric_to_sqla(
|
def adhoc_metric_to_sqla(
|
||||||
self,
|
self,
|
||||||
metric: AdhocMetric,
|
metric: AdhocMetric,
|
||||||
columns_by_name: Dict[str, TableColumn],
|
columns_by_name: dict[str, TableColumn],
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> ColumnElement:
|
) -> ColumnElement:
|
||||||
"""
|
"""
|
||||||
Turn an adhoc metric into a sqlalchemy column.
|
Turn an adhoc metric into a sqlalchemy column.
|
||||||
|
@ -946,7 +935,7 @@ class SqlaTable(
|
||||||
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
|
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
|
||||||
metric_column = metric.get("column") or {}
|
metric_column = metric.get("column") or {}
|
||||||
column_name = cast(str, metric_column.get("column_name"))
|
column_name = cast(str, metric_column.get("column_name"))
|
||||||
table_column: Optional[TableColumn] = columns_by_name.get(column_name)
|
table_column: TableColumn | None = columns_by_name.get(column_name)
|
||||||
if table_column:
|
if table_column:
|
||||||
sqla_column = table_column.get_sqla_col(
|
sqla_column = table_column.get_sqla_col(
|
||||||
template_processor=template_processor
|
template_processor=template_processor
|
||||||
|
@ -971,7 +960,7 @@ class SqlaTable(
|
||||||
self,
|
self,
|
||||||
col: AdhocColumn,
|
col: AdhocColumn,
|
||||||
force_type_check: bool = False,
|
force_type_check: bool = False,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> ColumnElement:
|
) -> ColumnElement:
|
||||||
"""
|
"""
|
||||||
Turn an adhoc column into a sqlalchemy column.
|
Turn an adhoc column into a sqlalchemy column.
|
||||||
|
@ -1021,7 +1010,7 @@ class SqlaTable(
|
||||||
return self.make_sqla_column_compatible(sqla_column, label)
|
return self.make_sqla_column_compatible(sqla_column, label)
|
||||||
|
|
||||||
def make_sqla_column_compatible(
|
def make_sqla_column_compatible(
|
||||||
self, sqla_col: ColumnElement, label: Optional[str] = None
|
self, sqla_col: ColumnElement, label: str | None = None
|
||||||
) -> ColumnElement:
|
) -> ColumnElement:
|
||||||
"""Takes a sqlalchemy column object and adds label info if supported by engine.
|
"""Takes a sqlalchemy column object and adds label info if supported by engine.
|
||||||
:param sqla_col: sqlalchemy column instance
|
:param sqla_col: sqlalchemy column instance
|
||||||
|
@ -1038,7 +1027,7 @@ class SqlaTable(
|
||||||
return sqla_col
|
return sqla_col
|
||||||
|
|
||||||
def make_orderby_compatible(
|
def make_orderby_compatible(
|
||||||
self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
|
self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
If needed, make sure aliases for selected columns are not used in
|
If needed, make sure aliases for selected columns are not used in
|
||||||
|
@ -1069,7 +1058,7 @@ class SqlaTable(
|
||||||
def get_sqla_row_level_filters(
|
def get_sqla_row_level_filters(
|
||||||
self,
|
self,
|
||||||
template_processor: BaseTemplateProcessor,
|
template_processor: BaseTemplateProcessor,
|
||||||
) -> List[TextClause]:
|
) -> list[TextClause]:
|
||||||
"""
|
"""
|
||||||
Return the appropriate row level security filters for this table and the
|
Return the appropriate row level security filters for this table and the
|
||||||
current user. A custom username can be passed when the user is not present in the
|
current user. A custom username can be passed when the user is not present in the
|
||||||
|
@ -1078,8 +1067,8 @@ class SqlaTable(
|
||||||
:param template_processor: The template processor to apply to the filters.
|
:param template_processor: The template processor to apply to the filters.
|
||||||
:returns: A list of SQL clauses to be ANDed together.
|
:returns: A list of SQL clauses to be ANDed together.
|
||||||
"""
|
"""
|
||||||
all_filters: List[TextClause] = []
|
all_filters: list[TextClause] = []
|
||||||
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
|
filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
|
||||||
try:
|
try:
|
||||||
for filter_ in security_manager.get_rls_filters(self):
|
for filter_ in security_manager.get_rls_filters(self):
|
||||||
clause = self.text(
|
clause = self.text(
|
||||||
|
@ -1114,9 +1103,9 @@ class SqlaTable(
|
||||||
def _get_series_orderby(
|
def _get_series_orderby(
|
||||||
self,
|
self,
|
||||||
series_limit_metric: Metric,
|
series_limit_metric: Metric,
|
||||||
metrics_by_name: Dict[str, SqlMetric],
|
metrics_by_name: dict[str, SqlMetric],
|
||||||
columns_by_name: Dict[str, TableColumn],
|
columns_by_name: dict[str, TableColumn],
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> Column:
|
) -> Column:
|
||||||
if utils.is_adhoc_metric(series_limit_metric):
|
if utils.is_adhoc_metric(series_limit_metric):
|
||||||
assert isinstance(series_limit_metric, dict)
|
assert isinstance(series_limit_metric, dict)
|
||||||
|
@ -1138,8 +1127,8 @@ class SqlaTable(
|
||||||
self,
|
self,
|
||||||
row: pd.Series,
|
row: pd.Series,
|
||||||
dimension: str,
|
dimension: str,
|
||||||
columns_by_name: Dict[str, TableColumn],
|
columns_by_name: dict[str, TableColumn],
|
||||||
) -> Union[str, int, float, bool, Text]:
|
) -> str | int | float | bool | Text:
|
||||||
"""
|
"""
|
||||||
Convert a prequery result type to its equivalent Python type.
|
Convert a prequery result type to its equivalent Python type.
|
||||||
|
|
||||||
|
@ -1159,7 +1148,7 @@ class SqlaTable(
|
||||||
value = value.item()
|
value = value.item()
|
||||||
|
|
||||||
column_ = columns_by_name[dimension]
|
column_ = columns_by_name[dimension]
|
||||||
db_extra: Dict[str, Any] = self.database.get_extra()
|
db_extra: dict[str, Any] = self.database.get_extra()
|
||||||
|
|
||||||
if column_.type and column_.is_temporal and isinstance(value, str):
|
if column_.type and column_.is_temporal and isinstance(value, str):
|
||||||
sql = self.db_engine_spec.convert_dttm(
|
sql = self.db_engine_spec.convert_dttm(
|
||||||
|
@ -1174,9 +1163,9 @@ class SqlaTable(
|
||||||
def _get_top_groups(
|
def _get_top_groups(
|
||||||
self,
|
self,
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
dimensions: List[str],
|
dimensions: list[str],
|
||||||
groupby_exprs: Dict[str, Any],
|
groupby_exprs: dict[str, Any],
|
||||||
columns_by_name: Dict[str, TableColumn],
|
columns_by_name: dict[str, TableColumn],
|
||||||
) -> ColumnElement:
|
) -> ColumnElement:
|
||||||
groups = []
|
groups = []
|
||||||
for _unused, row in df.iterrows():
|
for _unused, row in df.iterrows():
|
||||||
|
@ -1201,7 +1190,7 @@ class SqlaTable(
|
||||||
errors = None
|
errors = None
|
||||||
error_message = None
|
error_message = None
|
||||||
|
|
||||||
def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
|
def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
|
||||||
"""
|
"""
|
||||||
Some engines change the case or generate bespoke column names, either by
|
Some engines change the case or generate bespoke column names, either by
|
||||||
default or due to lack of support for aliasing. This function ensures that
|
default or due to lack of support for aliasing. This function ensures that
|
||||||
|
@ -1283,7 +1272,7 @@ class SqlaTable(
|
||||||
else self.columns
|
else self.columns
|
||||||
)
|
)
|
||||||
|
|
||||||
old_columns_by_name: Dict[str, TableColumn] = {
|
old_columns_by_name: dict[str, TableColumn] = {
|
||||||
col.column_name: col for col in old_columns
|
col.column_name: col for col in old_columns
|
||||||
}
|
}
|
||||||
results = MetadataResult(
|
results = MetadataResult(
|
||||||
|
@ -1341,8 +1330,8 @@ class SqlaTable(
|
||||||
session: Session,
|
session: Session,
|
||||||
database: Database,
|
database: Database,
|
||||||
datasource_name: str,
|
datasource_name: str,
|
||||||
schema: Optional[str] = None,
|
schema: str | None = None,
|
||||||
) -> List[SqlaTable]:
|
) -> list[SqlaTable]:
|
||||||
query = (
|
query = (
|
||||||
session.query(cls)
|
session.query(cls)
|
||||||
.filter_by(database_id=database.id)
|
.filter_by(database_id=database.id)
|
||||||
|
@ -1357,9 +1346,9 @@ class SqlaTable(
|
||||||
cls,
|
cls,
|
||||||
session: Session,
|
session: Session,
|
||||||
database: Database,
|
database: Database,
|
||||||
permissions: Set[str],
|
permissions: set[str],
|
||||||
schema_perms: Set[str],
|
schema_perms: set[str],
|
||||||
) -> List[SqlaTable]:
|
) -> list[SqlaTable]:
|
||||||
# TODO(hughhhh): add unit test
|
# TODO(hughhhh): add unit test
|
||||||
return (
|
return (
|
||||||
session.query(cls)
|
session.query(cls)
|
||||||
|
@ -1389,7 +1378,7 @@ class SqlaTable(
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
|
def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
|
||||||
qry = session.query(cls)
|
qry = session.query(cls)
|
||||||
qry = cls.default_query(qry)
|
qry = cls.default_query(qry)
|
||||||
return qry.all()
|
return qry.all()
|
||||||
|
@ -1409,7 +1398,7 @@ class SqlaTable(
|
||||||
:param query_obj: query object to analyze
|
:param query_obj: query object to analyze
|
||||||
:return: True if there are call(s) to an `ExtraCache` method, False otherwise
|
:return: True if there are call(s) to an `ExtraCache` method, False otherwise
|
||||||
"""
|
"""
|
||||||
templatable_statements: List[str] = []
|
templatable_statements: list[str] = []
|
||||||
if self.sql:
|
if self.sql:
|
||||||
templatable_statements.append(self.sql)
|
templatable_statements.append(self.sql)
|
||||||
if self.fetch_values_predicate:
|
if self.fetch_values_predicate:
|
||||||
|
@ -1428,7 +1417,7 @@ class SqlaTable(
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
|
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
|
||||||
"""
|
"""
|
||||||
The cache key of a SqlaTable needs to consider any keys added by the parent
|
The cache key of a SqlaTable needs to consider any keys added by the parent
|
||||||
class and any keys added via `ExtraCache`.
|
class and any keys added via `ExtraCache`.
|
||||||
|
@ -1489,7 +1478,7 @@ class SqlaTable(
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_column( # pylint: disable=unused-argument
|
def update_column( # pylint: disable=unused-argument
|
||||||
mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
|
mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param mapper: Unused.
|
:param mapper: Unused.
|
||||||
|
|
|
@ -17,19 +17,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Iterable, Iterator
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import (
|
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Type,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
@ -58,8 +48,8 @@ if TYPE_CHECKING:
|
||||||
def get_physical_table_metadata(
|
def get_physical_table_metadata(
|
||||||
database: Database,
|
database: Database,
|
||||||
table_name: str,
|
table_name: str,
|
||||||
schema_name: Optional[str] = None,
|
schema_name: str | None = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Use SQLAlchemy inspector to get table metadata"""
|
"""Use SQLAlchemy inspector to get table metadata"""
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
db_dialect = database.get_dialect()
|
db_dialect = database.get_dialect()
|
||||||
|
@ -103,7 +93,7 @@ def get_physical_table_metadata(
|
||||||
return cols
|
return cols
|
||||||
|
|
||||||
|
|
||||||
def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
|
def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
|
||||||
"""Use SQLparser to get virtual dataset metadata"""
|
"""Use SQLparser to get virtual dataset metadata"""
|
||||||
if not dataset.sql:
|
if not dataset.sql:
|
||||||
raise SupersetGenericDBErrorException(
|
raise SupersetGenericDBErrorException(
|
||||||
|
@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
|
||||||
def get_columns_description(
|
def get_columns_description(
|
||||||
database: Database,
|
database: Database,
|
||||||
query: str,
|
query: str,
|
||||||
) -> List[ResultSetColumnType]:
|
) -> list[ResultSetColumnType]:
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
try:
|
try:
|
||||||
with database.get_raw_connection() as conn:
|
with database.get_raw_connection() as conn:
|
||||||
|
@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
|
||||||
def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
|
def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]:
|
||||||
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote
|
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote
|
||||||
|
|
||||||
|
|
||||||
|
@ -181,9 +171,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def find_cached_objects_in_session(
|
def find_cached_objects_in_session(
|
||||||
session: Session,
|
session: Session,
|
||||||
cls: Type[DeclarativeModel],
|
cls: type[DeclarativeModel],
|
||||||
ids: Optional[Iterable[int]] = None,
|
ids: Iterable[int] | None = None,
|
||||||
uuids: Optional[Iterable[UUID]] = None,
|
uuids: Iterable[UUID] | None = None,
|
||||||
) -> Iterator[DeclarativeModel]:
|
) -> Iterator[DeclarativeModel]:
|
||||||
"""Find known ORM instances in cached SQLA session states.
|
"""Find known ORM instances in cached SQLA session states.
|
||||||
|
|
||||||
|
|
|
@ -447,7 +447,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
|
||||||
resp = super().edit(pk)
|
resp = super().edit(pk)
|
||||||
if isinstance(resp, str):
|
if isinstance(resp, str):
|
||||||
return resp
|
return resp
|
||||||
return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk))
|
return redirect(f"/explore/?datasource_type=table&datasource_id={pk}")
|
||||||
|
|
||||||
@expose("/list/")
|
@expose("/list/")
|
||||||
@has_access
|
@has_access
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from superset.commands.base import BaseCommand
|
from superset.commands.base import BaseCommand
|
||||||
from superset.css_templates.commands.exceptions import (
|
from superset.css_templates.commands.exceptions import (
|
||||||
|
@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BulkDeleteCssTemplateCommand(BaseCommand):
|
class BulkDeleteCssTemplateCommand(BaseCommand):
|
||||||
def __init__(self, model_ids: List[int]):
|
def __init__(self, model_ids: list[int]):
|
||||||
self._model_ids = model_ids
|
self._model_ids = model_ids
|
||||||
self._models: Optional[List[CssTemplate]] = None
|
self._models: Optional[list[CssTemplate]] = None
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO):
|
||||||
model_cls = CssTemplate
|
model_cls = CssTemplate
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None:
|
def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
|
||||||
item_ids = [model.id for model in models] if models else []
|
item_ids = [model.id for model in models] if models else []
|
||||||
try:
|
try:
|
||||||
db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete(
|
db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete(
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
# pylint: disable=isinstance-second-argument-not-valid-type
|
# pylint: disable=isinstance-second-argument-not-valid-type
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from flask_appbuilder.models.filters import BaseFilter
|
from flask_appbuilder.models.filters import BaseFilter
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
@ -37,7 +37,7 @@ class BaseDAO:
|
||||||
Base DAO, implement base CRUD sqlalchemy operations
|
Base DAO, implement base CRUD sqlalchemy operations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_cls: Optional[Type[Model]] = None
|
model_cls: Optional[type[Model]] = None
|
||||||
"""
|
"""
|
||||||
Child classes need to state the Model class so they don't need to implement basic
|
Child classes need to state the Model class so they don't need to implement basic
|
||||||
create, update and delete methods
|
create, update and delete methods
|
||||||
|
@ -75,10 +75,10 @@ class BaseDAO:
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_ids(
|
def find_by_ids(
|
||||||
cls,
|
cls,
|
||||||
model_ids: Union[List[str], List[int]],
|
model_ids: Union[list[str], list[int]],
|
||||||
session: Session = None,
|
session: Session = None,
|
||||||
skip_base_filter: bool = False,
|
skip_base_filter: bool = False,
|
||||||
) -> List[Model]:
|
) -> list[Model]:
|
||||||
"""
|
"""
|
||||||
Find a List of models by a list of ids, if defined applies `base_filter`
|
Find a List of models by a list of ids, if defined applies `base_filter`
|
||||||
"""
|
"""
|
||||||
|
@ -95,7 +95,7 @@ class BaseDAO:
|
||||||
return query.all()
|
return query.all()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_all(cls) -> List[Model]:
|
def find_all(cls) -> list[Model]:
|
||||||
"""
|
"""
|
||||||
Get all that fit the `base_filter`
|
Get all that fit the `base_filter`
|
||||||
"""
|
"""
|
||||||
|
@ -121,7 +121,7 @@ class BaseDAO:
|
||||||
return query.filter_by(**filter_by).one_or_none()
|
return query.filter_by(**filter_by).one_or_none()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
|
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
|
||||||
"""
|
"""
|
||||||
Generic for creating models
|
Generic for creating models
|
||||||
:raises: DAOCreateFailedError
|
:raises: DAOCreateFailedError
|
||||||
|
@ -163,7 +163,7 @@ class BaseDAO:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update(
|
def update(
|
||||||
cls, model: Model, properties: Dict[str, Any], commit: bool = True
|
cls, model: Model, properties: dict[str, Any], commit: bool = True
|
||||||
) -> Model:
|
) -> Model:
|
||||||
"""
|
"""
|
||||||
Generic update a model
|
Generic update a model
|
||||||
|
@ -196,7 +196,7 @@ class BaseDAO:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def bulk_delete(cls, models: List[Model], commit: bool = True) -> None:
|
def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
|
||||||
try:
|
try:
|
||||||
for model in models:
|
for model in models:
|
||||||
cls.delete(model, False)
|
cls.delete(model, False)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
|
||||||
|
@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BulkDeleteDashboardCommand(BaseCommand):
|
class BulkDeleteDashboardCommand(BaseCommand):
|
||||||
def __init__(self, model_ids: List[int]):
|
def __init__(self, model_ids: list[int]):
|
||||||
self._model_ids = model_ids
|
self._model_ids = model_ids
|
||||||
self._models: Optional[List[Dashboard]] = None
|
self._models: Optional[list[Dashboard]] = None
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateDashboardCommand(CreateMixin, BaseCommand):
|
class CreateDashboardCommand(CreateMixin, BaseCommand):
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
|
@ -48,9 +48,9 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
|
||||||
return dashboard
|
return dashboard
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
owner_ids: Optional[List[int]] = self._properties.get("owners")
|
owner_ids: Optional[list[int]] = self._properties.get("owners")
|
||||||
role_ids: Optional[List[int]] = self._properties.get("roles")
|
role_ids: Optional[list[int]] = self._properties.get("roles")
|
||||||
slug: str = self._properties.get("slug", "")
|
slug: str = self._properties.get("slug", "")
|
||||||
|
|
||||||
# Validate slug uniqueness
|
# Validate slug uniqueness
|
||||||
|
|
|
@ -20,7 +20,8 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from typing import Any, Dict, Iterator, Optional, Set, Tuple
|
from typing import Any, Optional
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_default_position(title: str) -> Dict[str, Any]:
|
def get_default_position(title: str) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"DASHBOARD_VERSION_KEY": "v2",
|
"DASHBOARD_VERSION_KEY": "v2",
|
||||||
"ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
|
"ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
|
||||||
|
@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]:
|
def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]:
|
||||||
chart_hashes = [f"CHART-{suffix()}" for _ in charts]
|
chart_hashes = [f"CHART-{suffix()}" for _ in charts]
|
||||||
|
|
||||||
# if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid
|
# if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid
|
||||||
|
@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _export(
|
def _export(
|
||||||
model: Dashboard, export_related: bool = True
|
model: Dashboard, export_related: bool = True
|
||||||
) -> Iterator[Tuple[str, str]]:
|
) -> Iterator[tuple[str, str]]:
|
||||||
file_name = get_filename(model.dashboard_title, model.id)
|
file_name = get_filename(model.dashboard_title, model.id)
|
||||||
file_path = f"dashboards/{file_name}.yaml"
|
file_path = f"dashboards/{file_name}.yaml"
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from marshmallow.exceptions import ValidationError
|
from marshmallow.exceptions import ValidationError
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand):
|
||||||
until it finds one that matches.
|
until it finds one that matches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
|
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
|
@ -19,7 +19,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
from sqlalchemy.orm import make_transient, Session
|
from sqlalchemy.orm import make_transient, Session
|
||||||
|
@ -83,7 +83,7 @@ def import_chart(
|
||||||
def import_dashboard(
|
def import_dashboard(
|
||||||
# pylint: disable=too-many-locals,too-many-statements
|
# pylint: disable=too-many-locals,too-many-statements
|
||||||
dashboard_to_import: Dashboard,
|
dashboard_to_import: Dashboard,
|
||||||
dataset_id_mapping: Optional[Dict[int, int]] = None,
|
dataset_id_mapping: Optional[dict[int, int]] = None,
|
||||||
import_time: Optional[int] = None,
|
import_time: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Imports the dashboard from the object to the database.
|
"""Imports the dashboard from the object to the database.
|
||||||
|
@ -97,7 +97,7 @@ def import_dashboard(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def alter_positions(
|
def alter_positions(
|
||||||
dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
|
dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Updates slice_ids in the position json.
|
"""Updates slice_ids in the position json.
|
||||||
|
|
||||||
|
@ -166,7 +166,7 @@ def import_dashboard(
|
||||||
dashboard_to_import.slug = None
|
dashboard_to_import.slug = None
|
||||||
|
|
||||||
old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
|
old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
|
||||||
old_to_new_slc_id_dict: Dict[int, int] = {}
|
old_to_new_slc_id_dict: dict[int, int] = {}
|
||||||
new_timed_refresh_immune_slices = []
|
new_timed_refresh_immune_slices = []
|
||||||
new_expanded_slices = {}
|
new_expanded_slices = {}
|
||||||
new_filter_scopes = {}
|
new_filter_scopes = {}
|
||||||
|
@ -268,7 +268,7 @@ def import_dashboard(
|
||||||
return dashboard_to_import.id # type: ignore
|
return dashboard_to_import.id # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def decode_dashboards(o: Dict[str, Any]) -> Any:
|
def decode_dashboards(o: dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Function to be passed into json.loads obj_hook parameter
|
Function to be passed into json.loads obj_hook parameter
|
||||||
Recreates the dashboard object from a json representation.
|
Recreates the dashboard object from a json representation.
|
||||||
|
@ -302,7 +302,7 @@ def import_dashboards(
|
||||||
data = json.loads(content, object_hook=decode_dashboards)
|
data = json.loads(content, object_hook=decode_dashboards)
|
||||||
if not data:
|
if not data:
|
||||||
raise DashboardImportException(_("No data in file"))
|
raise DashboardImportException(_("No data in file"))
|
||||||
dataset_id_mapping: Dict[int, int] = {}
|
dataset_id_mapping: dict[int, int] = {}
|
||||||
for table in data["datasources"]:
|
for table in data["datasources"]:
|
||||||
new_dataset_id = import_dataset(table, database_id, import_time=import_time)
|
new_dataset_id = import_dataset(table, database_id, import_time=import_time)
|
||||||
params = json.loads(table.params)
|
params = json.loads(table.params)
|
||||||
|
@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand):
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def __init__(
|
def __init__(
|
||||||
self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any
|
self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any
|
||||||
):
|
):
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
self.database_id = database_id
|
self.database_id = database_id
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set, Tuple
|
from typing import Any
|
||||||
|
|
||||||
from marshmallow import Schema
|
from marshmallow import Schema
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
|
||||||
dao = DashboardDAO
|
dao = DashboardDAO
|
||||||
model_name = "dashboard"
|
model_name = "dashboard"
|
||||||
prefix = "dashboards/"
|
prefix = "dashboards/"
|
||||||
schemas: Dict[str, Schema] = {
|
schemas: dict[str, Schema] = {
|
||||||
"charts/": ImportV1ChartSchema(),
|
"charts/": ImportV1ChartSchema(),
|
||||||
"dashboards/": ImportV1DashboardSchema(),
|
"dashboards/": ImportV1DashboardSchema(),
|
||||||
"datasets/": ImportV1DatasetSchema(),
|
"datasets/": ImportV1DatasetSchema(),
|
||||||
|
@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand):
|
||||||
# pylint: disable=too-many-branches, too-many-locals
|
# pylint: disable=too-many-branches, too-many-locals
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import(
|
def _import(
|
||||||
session: Session, configs: Dict[str, Any], overwrite: bool = False
|
session: Session, configs: dict[str, Any], overwrite: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
# discover charts and datasets associated with dashboards
|
# discover charts and datasets associated with dashboards
|
||||||
chart_uuids: Set[str] = set()
|
chart_uuids: set[str] = set()
|
||||||
dataset_uuids: Set[str] = set()
|
dataset_uuids: set[str] = set()
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("dashboards/"):
|
if file_name.startswith("dashboards/"):
|
||||||
chart_uuids.update(find_chart_uuids(config["position"]))
|
chart_uuids.update(find_chart_uuids(config["position"]))
|
||||||
|
@ -77,20 +77,20 @@ class ImportDashboardsCommand(ImportModelsCommand):
|
||||||
dataset_uuids.add(config["dataset_uuid"])
|
dataset_uuids.add(config["dataset_uuid"])
|
||||||
|
|
||||||
# discover databases associated with datasets
|
# discover databases associated with datasets
|
||||||
database_uuids: Set[str] = set()
|
database_uuids: set[str] = set()
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
|
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
|
||||||
database_uuids.add(config["database_uuid"])
|
database_uuids.add(config["database_uuid"])
|
||||||
|
|
||||||
# import related databases
|
# import related databases
|
||||||
database_ids: Dict[str, int] = {}
|
database_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
|
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
|
||||||
database = import_database(session, config, overwrite=False)
|
database = import_database(session, config, overwrite=False)
|
||||||
database_ids[str(database.uuid)] = database.id
|
database_ids[str(database.uuid)] = database.id
|
||||||
|
|
||||||
# import datasets with the correct parent ref
|
# import datasets with the correct parent ref
|
||||||
dataset_info: Dict[str, Dict[str, Any]] = {}
|
dataset_info: dict[str, dict[str, Any]] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if (
|
if (
|
||||||
file_name.startswith("datasets/")
|
file_name.startswith("datasets/")
|
||||||
|
@ -105,7 +105,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
|
||||||
}
|
}
|
||||||
|
|
||||||
# import charts with the correct parent ref
|
# import charts with the correct parent ref
|
||||||
chart_ids: Dict[str, int] = {}
|
chart_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if (
|
if (
|
||||||
file_name.startswith("charts/")
|
file_name.startswith("charts/")
|
||||||
|
@ -129,7 +129,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
# import dashboards
|
# import dashboards
|
||||||
dashboard_chart_ids: List[Tuple[int, int]] = []
|
dashboard_chart_ids: list[tuple[int, int]] = []
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("dashboards/"):
|
if file_name.startswith("dashboards/"):
|
||||||
config = update_id_refs(config, chart_ids, dataset_info)
|
config = update_id_refs(config, chart_ids, dataset_info)
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Set
|
from typing import Any
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -32,12 +32,12 @@ logger = logging.getLogger(__name__)
|
||||||
JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"}
|
JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"}
|
||||||
|
|
||||||
|
|
||||||
def find_chart_uuids(position: Dict[str, Any]) -> Set[str]:
|
def find_chart_uuids(position: dict[str, Any]) -> set[str]:
|
||||||
return set(build_uuid_to_id_map(position))
|
return set(build_uuid_to_id_map(position))
|
||||||
|
|
||||||
|
|
||||||
def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
|
def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]:
|
||||||
uuids: Set[str] = set()
|
uuids: set[str] = set()
|
||||||
for native_filter in metadata.get("native_filter_configuration", []):
|
for native_filter in metadata.get("native_filter_configuration", []):
|
||||||
targets = native_filter.get("targets", [])
|
targets = native_filter.get("targets", [])
|
||||||
for target in targets:
|
for target in targets:
|
||||||
|
@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
|
||||||
return uuids
|
return uuids
|
||||||
|
|
||||||
|
|
||||||
def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
|
def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]:
|
||||||
return {
|
return {
|
||||||
child["meta"]["uuid"]: child["meta"]["chartId"]
|
child["meta"]["uuid"]: child["meta"]["chartId"]
|
||||||
for child in position.values()
|
for child in position.values()
|
||||||
|
@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
|
||||||
|
|
||||||
|
|
||||||
def update_id_refs( # pylint: disable=too-many-locals
|
def update_id_refs( # pylint: disable=too-many-locals
|
||||||
config: Dict[str, Any],
|
config: dict[str, Any],
|
||||||
chart_ids: Dict[str, int],
|
chart_ids: dict[str, int],
|
||||||
dataset_info: Dict[str, Dict[str, Any]],
|
dataset_info: dict[str, dict[str, Any]],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Update dashboard metadata to use new IDs"""
|
"""Update dashboard metadata to use new IDs"""
|
||||||
fixed = config.copy()
|
fixed = config.copy()
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ def update_id_refs( # pylint: disable=too-many-locals
|
||||||
|
|
||||||
def import_dashboard(
|
def import_dashboard(
|
||||||
session: Session,
|
session: Session,
|
||||||
config: Dict[str, Any],
|
config: dict[str, Any],
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
ignore_permissions: bool = False,
|
ignore_permissions: bool = False,
|
||||||
) -> Dashboard:
|
) -> Dashboard:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateDashboardCommand(UpdateMixin, BaseCommand):
|
class UpdateDashboardCommand(UpdateMixin, BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[Dashboard] = None
|
self._model: Optional[Dashboard] = None
|
||||||
|
@ -64,9 +64,9 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
|
||||||
return dashboard
|
return dashboard
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
owners_ids: Optional[List[int]] = self._properties.get("owners")
|
owners_ids: Optional[list[int]] = self._properties.get("owners")
|
||||||
roles_ids: Optional[List[int]] = self._properties.get("roles")
|
roles_ids: Optional[list[int]] = self._properties.get("roles")
|
||||||
slug: Optional[str] = self._properties.get("slug")
|
slug: Optional[str] = self._properties.get("slug")
|
||||||
|
|
||||||
# Validate/populate model exists
|
# Validate/populate model exists
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||||
|
@ -68,12 +68,12 @@ class DashboardDAO(BaseDAO):
|
||||||
return dashboard
|
return dashboard
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]:
|
def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]:
|
||||||
dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug)
|
dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug)
|
||||||
return dashboard.datasets_trimmed_for_slices()
|
return dashboard.datasets_trimmed_for_slices()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]:
|
def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]:
|
||||||
return DashboardDAO.get_by_id_or_slug(id_or_slug).slices
|
return DashboardDAO.get_by_id_or_slug(id_or_slug).slices
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -173,7 +173,7 @@ class DashboardDAO(BaseDAO):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None:
|
def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None:
|
||||||
item_ids = [model.id for model in models] if models else []
|
item_ids = [model.id for model in models] if models else []
|
||||||
# bulk delete, first delete related data
|
# bulk delete, first delete related data
|
||||||
if models:
|
if models:
|
||||||
|
@ -196,8 +196,8 @@ class DashboardDAO(BaseDAO):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_dash_metadata( # pylint: disable=too-many-locals
|
def set_dash_metadata( # pylint: disable=too-many-locals
|
||||||
dashboard: Dashboard,
|
dashboard: Dashboard,
|
||||||
data: Dict[Any, Any],
|
data: dict[Any, Any],
|
||||||
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
|
old_to_new_slice_ids: Optional[dict[int, int]] = None,
|
||||||
commit: bool = False,
|
commit: bool = False,
|
||||||
) -> Dashboard:
|
) -> Dashboard:
|
||||||
new_filter_scopes = {}
|
new_filter_scopes = {}
|
||||||
|
@ -235,7 +235,7 @@ class DashboardDAO(BaseDAO):
|
||||||
if "filter_scopes" in data:
|
if "filter_scopes" in data:
|
||||||
# replace filter_id and immune ids from old slice id to new slice id:
|
# replace filter_id and immune ids from old slice id to new slice id:
|
||||||
# and remove slice ids that are not in dash anymore
|
# and remove slice ids that are not in dash anymore
|
||||||
slc_id_dict: Dict[int, int] = {}
|
slc_id_dict: dict[int, int] = {}
|
||||||
if old_to_new_slice_ids:
|
if old_to_new_slice_ids:
|
||||||
slc_id_dict = {
|
slc_id_dict = {
|
||||||
old: new
|
old: new
|
||||||
|
@ -288,7 +288,7 @@ class DashboardDAO(BaseDAO):
|
||||||
return dashboard
|
return dashboard
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
|
def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
|
||||||
ids = [dash.id for dash in dashboards]
|
ids = [dash.id for dash in dashboards]
|
||||||
return [
|
return [
|
||||||
star.obj_id
|
star.obj_id
|
||||||
|
@ -303,7 +303,7 @@ class DashboardDAO(BaseDAO):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def copy_dashboard(
|
def copy_dashboard(
|
||||||
cls, original_dash: Dashboard, data: Dict[str, Any]
|
cls, original_dash: Dashboard, data: dict[str, Any]
|
||||||
) -> Dashboard:
|
) -> Dashboard:
|
||||||
dash = Dashboard()
|
dash = Dashboard()
|
||||||
dash.owners = [g.user] if g.user else []
|
dash.owners = [g.user] if g.user else []
|
||||||
|
@ -311,7 +311,7 @@ class DashboardDAO(BaseDAO):
|
||||||
dash.css = data.get("css")
|
dash.css = data.get("css")
|
||||||
|
|
||||||
metadata = json.loads(data["json_metadata"])
|
metadata = json.loads(data["json_metadata"])
|
||||||
old_to_new_slice_ids: Dict[int, int] = {}
|
old_to_new_slice_ids: dict[int, int] = {}
|
||||||
if data.get("duplicate_slices"):
|
if data.get("duplicate_slices"):
|
||||||
# Duplicating slices as well, mapping old ids to new ones
|
# Duplicating slices as well, mapping old ids to new ones
|
||||||
for slc in original_dash.slices:
|
for slc in original_dash.slices:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class CreateFilterSetCommand(BaseFilterSetCommand):
|
class CreateFilterSetCommand(BaseFilterSetCommand):
|
||||||
# pylint: disable=C0103
|
# pylint: disable=C0103
|
||||||
def __init__(self, dashboard_id: int, data: Dict[str, Any]):
|
def __init__(self, dashboard_id: int, data: dict[str, Any]):
|
||||||
super().__init__(dashboard_id)
|
super().__init__(dashboard_id)
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateFilterSetCommand(BaseFilterSetCommand):
|
class UpdateFilterSetCommand(BaseFilterSetCommand):
|
||||||
def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]):
|
def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]):
|
||||||
super().__init__(dashboard_id)
|
super().__init__(dashboard_id)
|
||||||
self._filter_set_id = filter_set_id
|
self._filter_set_id = filter_set_id
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO):
|
||||||
model_cls = FilterSet
|
model_cls = FilterSet
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
|
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
|
||||||
if cls.model_cls is None:
|
if cls.model_cls is None:
|
||||||
raise DAOConfigError()
|
raise DAOConfigError()
|
||||||
model = FilterSet()
|
model = FilterSet()
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, cast, Dict, Mapping
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from marshmallow import fields, post_load, Schema, ValidationError
|
from marshmallow import fields, post_load, Schema, ValidationError
|
||||||
from marshmallow.validate import Length, OneOf
|
from marshmallow.validate import Length, OneOf
|
||||||
|
@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema):
|
||||||
@post_load
|
@post_load
|
||||||
def validate(
|
def validate(
|
||||||
self, data: Mapping[Any, Any], *, many: Any, partial: Any
|
self, data: Mapping[Any, Any], *, many: Any, partial: Any
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
|
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
|
||||||
if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data:
|
if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data:
|
||||||
raise ValidationError("owner_id is mandatory when owner_type is User")
|
raise ValidationError("owner_id is mandatory when owner_type is User")
|
||||||
return cast(Dict[str, Any], data)
|
return cast(dict[str, Any], data)
|
||||||
|
|
||||||
|
|
||||||
class FilterSetPutSchema(FilterSetSchema):
|
class FilterSetPutSchema(FilterSetSchema):
|
||||||
|
@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema):
|
||||||
@post_load
|
@post_load
|
||||||
def validate( # pylint: disable=unused-argument
|
def validate( # pylint: disable=unused-argument
|
||||||
self, data: Mapping[Any, Any], *, many: Any, partial: Any
|
self, data: Mapping[Any, Any], *, many: Any, partial: Any
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if JSON_METADATA_FIELD in data:
|
if JSON_METADATA_FIELD in data:
|
||||||
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
|
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
|
||||||
return cast(Dict[str, Any], data)
|
return cast(dict[str, Any], data)
|
||||||
|
|
||||||
|
|
||||||
def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None:
|
def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None:
|
||||||
if first_field in data and second_field not in data:
|
if first_field in data and second_field not in data:
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
"{} must be included alongside {}".format(first_field, second_field)
|
f"{first_field} must be included alongside {second_field}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response
|
||||||
from flask_appbuilder.api import expose, protect, safe
|
from flask_appbuilder.api import expose, protect, safe
|
||||||
|
@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi):
|
||||||
resource_name = "dashboard"
|
resource_name = "dashboard"
|
||||||
openapi_spec_tag = "Dashboard Filter State"
|
openapi_spec_tag = "Dashboard Filter State"
|
||||||
|
|
||||||
def get_create_command(self) -> Type[CreateFilterStateCommand]:
|
def get_create_command(self) -> type[CreateFilterStateCommand]:
|
||||||
return CreateFilterStateCommand
|
return CreateFilterStateCommand
|
||||||
|
|
||||||
def get_update_command(self) -> Type[UpdateFilterStateCommand]:
|
def get_update_command(self) -> type[UpdateFilterStateCommand]:
|
||||||
return UpdateFilterStateCommand
|
return UpdateFilterStateCommand
|
||||||
|
|
||||||
def get_get_command(self) -> Type[GetFilterStateCommand]:
|
def get_get_command(self) -> type[GetFilterStateCommand]:
|
||||||
return GetFilterStateCommand
|
return GetFilterStateCommand
|
||||||
|
|
||||||
def get_delete_command(self) -> Type[DeleteFilterStateCommand]:
|
def get_delete_command(self) -> type[DeleteFilterStateCommand]:
|
||||||
return DeleteFilterStateCommand
|
return DeleteFilterStateCommand
|
||||||
|
|
||||||
@expose("/<int:pk>/filter_state", methods=("POST",))
|
@expose("/<int:pk>/filter_state", methods=("POST",))
|
||||||
|
|
|
@ -14,14 +14,14 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict
|
from typing import Any, Optional, TypedDict
|
||||||
|
|
||||||
|
|
||||||
class DashboardPermalinkState(TypedDict):
|
class DashboardPermalinkState(TypedDict):
|
||||||
dataMask: Optional[Dict[str, Any]]
|
dataMask: Optional[dict[str, Any]]
|
||||||
activeTabs: Optional[List[str]]
|
activeTabs: Optional[list[str]]
|
||||||
anchor: Optional[str]
|
anchor: Optional[str]
|
||||||
urlParams: Optional[List[Tuple[str, str]]]
|
urlParams: Optional[list[tuple[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
class DashboardPermalinkValue(TypedDict):
|
class DashboardPermalinkValue(TypedDict):
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from marshmallow import fields, post_load, pre_load, Schema
|
from marshmallow import fields, post_load, pre_load, Schema
|
||||||
from marshmallow.validate import Length, ValidationError
|
from marshmallow.validate import Length, ValidationError
|
||||||
|
@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema):
|
||||||
@pre_load
|
@pre_load
|
||||||
def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use
|
def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use
|
||||||
self,
|
self,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Remove ``show_native_filters`` from the JSON metadata.
|
Remove ``show_native_filters`` from the JSON metadata.
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema):
|
||||||
class BaseDashboardSchema(Schema):
|
class BaseDashboardSchema(Schema):
|
||||||
# pylint: disable=no-self-use,unused-argument
|
# pylint: disable=no-self-use,unused-argument
|
||||||
@post_load
|
@post_load
|
||||||
def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
|
def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
|
||||||
if data.get("slug"):
|
if data.get("slug"):
|
||||||
data["slug"] = data["slug"].strip()
|
data["slug"] = data["slug"].strip()
|
||||||
data["slug"] = data["slug"].replace(" ", "-")
|
data["slug"] = data["slug"].replace(" ", "-")
|
||||||
|
|
|
@ -19,7 +19,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, cast, Dict, List, Optional
|
from typing import Any, cast, Optional
|
||||||
from zipfile import is_zipfile, ZipFile
|
from zipfile import is_zipfile, ZipFile
|
||||||
|
|
||||||
from flask import request, Response, send_file
|
from flask import request, Response, send_file
|
||||||
|
@ -1328,13 +1328,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
500:
|
500:
|
||||||
$ref: '#/components/responses/500'
|
$ref: '#/components/responses/500'
|
||||||
"""
|
"""
|
||||||
preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", [])
|
preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", [])
|
||||||
available_databases = []
|
available_databases = []
|
||||||
for engine_spec, drivers in get_available_engine_specs().items():
|
for engine_spec, drivers in get_available_engine_specs().items():
|
||||||
if not drivers:
|
if not drivers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload: Dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"name": engine_spec.engine_name,
|
"name": engine_spec.engine_name,
|
||||||
"engine": engine_spec.engine,
|
"engine": engine_spec.engine,
|
||||||
"available_drivers": sorted(drivers),
|
"available_drivers": sorted(drivers),
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
@ -47,7 +47,7 @@ stats_logger = current_app.config["STATS_LOGGER"]
|
||||||
|
|
||||||
|
|
||||||
class CreateDatabaseCommand(BaseCommand):
|
class CreateDatabaseCommand(BaseCommand):
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
|
@ -128,7 +128,7 @@ class CreateDatabaseCommand(BaseCommand):
|
||||||
return database
|
return database
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
|
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
|
||||||
database_name: Optional[str] = self._properties.get("database_name")
|
database_name: Optional[str] = self._properties.get("database_name")
|
||||||
if not sqlalchemy_uri:
|
if not sqlalchemy_uri:
|
||||||
|
|
|
@ -18,7 +18,8 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterator, Tuple
|
from typing import Any
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -33,7 +34,7 @@ from superset.utils.ssh_tunnel import mask_password_info
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_extra(extra_payload: str) -> Dict[str, Any]:
|
def parse_extra(extra_payload: str) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
extra = json.loads(extra_payload)
|
extra = json.loads(extra_payload)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
|
@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _export(
|
def _export(
|
||||||
model: Database, export_related: bool = True
|
model: Database, export_related: bool = True
|
||||||
) -> Iterator[Tuple[str, str]]:
|
) -> Iterator[tuple[str, str]]:
|
||||||
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
|
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
|
||||||
file_path = f"databases/{db_file_name}.yaml"
|
file_path = f"databases/{db_file_name}.yaml"
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from marshmallow.exceptions import ValidationError
|
from marshmallow.exceptions import ValidationError
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand):
|
||||||
until it finds one that matches.
|
until it finds one that matches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
|
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||||
self.contents = contents
|
self.contents = contents
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from marshmallow import Schema
|
from marshmallow import Schema
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand):
|
||||||
dao = DatabaseDAO
|
dao = DatabaseDAO
|
||||||
model_name = "database"
|
model_name = "database"
|
||||||
prefix = "databases/"
|
prefix = "databases/"
|
||||||
schemas: Dict[str, Schema] = {
|
schemas: dict[str, Schema] = {
|
||||||
"databases/": ImportV1DatabaseSchema(),
|
"databases/": ImportV1DatabaseSchema(),
|
||||||
"datasets/": ImportV1DatasetSchema(),
|
"datasets/": ImportV1DatasetSchema(),
|
||||||
}
|
}
|
||||||
|
@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import(
|
def _import(
|
||||||
session: Session, configs: Dict[str, Any], overwrite: bool = False
|
session: Session, configs: dict[str, Any], overwrite: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
# first import databases
|
# first import databases
|
||||||
database_ids: Dict[str, int] = {}
|
database_ids: dict[str, int] = {}
|
||||||
for file_name, config in configs.items():
|
for file_name, config in configs.items():
|
||||||
if file_name.startswith("databases/"):
|
if file_name.startswith("databases/"):
|
||||||
database = import_database(session, config, overwrite=overwrite)
|
database = import_database(session, config, overwrite=overwrite)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from superset.models.core import Database
|
||||||
|
|
||||||
def import_database(
|
def import_database(
|
||||||
session: Session,
|
session: Session,
|
||||||
config: Dict[str, Any],
|
config: dict[str, Any],
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
ignore_permissions: bool = False,
|
ignore_permissions: bool = False,
|
||||||
) -> Database:
|
) -> Database:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast, Dict
|
from typing import Any, cast
|
||||||
|
|
||||||
from superset.commands.base import BaseCommand
|
from superset.commands.base import BaseCommand
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
@ -40,7 +40,7 @@ class TablesDatabaseCommand(BaseCommand):
|
||||||
self._schema_name = schema_name
|
self._schema_name = schema_name
|
||||||
self._force = force
|
self._force = force
|
||||||
|
|
||||||
def run(self) -> Dict[str, Any]:
|
def run(self) -> dict[str, Any]:
|
||||||
self.validate()
|
self.validate()
|
||||||
try:
|
try:
|
||||||
tables = security_manager.get_datasources_accessible_by_user(
|
tables = security_manager.get_datasources_accessible_by_user(
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask import current_app as app
|
from flask import current_app as app
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
|
@ -64,7 +64,7 @@ def get_log_connection_action(
|
||||||
|
|
||||||
|
|
||||||
class TestConnectionDatabaseCommand(BaseCommand):
|
class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
def __init__(self, data: Dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[Database] = None
|
self._model: Optional[Database] = None
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateDatabaseCommand(BaseCommand):
|
class UpdateDatabaseCommand(BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._model: Optional[Database] = None
|
self._model: Optional[Database] = None
|
||||||
|
@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
raise DatabaseConnectionFailedError() from ex
|
raise DatabaseConnectionFailedError() from ex
|
||||||
|
|
||||||
# Update database schema permissions
|
# Update database schema permissions
|
||||||
new_schemas: List[str] = []
|
new_schemas: list[str] = []
|
||||||
|
|
||||||
for schema in schemas:
|
for schema in schemas:
|
||||||
old_view_menu_name = security_manager.get_schema_perm(
|
old_view_menu_name = security_manager.get_schema_perm(
|
||||||
|
@ -164,7 +164,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
chart.schema_perm = new_view_menu_name
|
chart.schema_perm = new_view_menu_name
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
# Validate/populate model exists
|
# Validate/populate model exists
|
||||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||||
if not self._model:
|
if not self._model:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import json
|
import json
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"}
|
||||||
|
|
||||||
|
|
||||||
class ValidateDatabaseParametersCommand(BaseCommand):
|
class ValidateDatabaseParametersCommand(BaseCommand):
|
||||||
def __init__(self, properties: Dict[str, Any]):
|
def __init__(self, properties: dict[str, Any]):
|
||||||
self._properties = properties.copy()
|
self._properties = properties.copy()
|
||||||
self._model: Optional[Database] = None
|
self._model: Optional[Database] = None
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ValidateSQLCommand(BaseCommand):
|
class ValidateSQLCommand(BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._model: Optional[Database] = None
|
self._model: Optional[Database] = None
|
||||||
self._validator: Optional[Type[BaseSQLValidator]] = None
|
self._validator: Optional[type[BaseSQLValidator]] = None
|
||||||
|
|
||||||
def run(self) -> List[Dict[str, Any]]:
|
def run(self) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Validates a SQL statement
|
Validates a SQL statement
|
||||||
|
|
||||||
|
@ -97,9 +97,7 @@ class ValidateSQLCommand(BaseCommand):
|
||||||
if not validators_by_engine or spec.engine not in validators_by_engine:
|
if not validators_by_engine or spec.engine not in validators_by_engine:
|
||||||
raise NoValidatorConfigFoundError(
|
raise NoValidatorConfigFoundError(
|
||||||
SupersetError(
|
SupersetError(
|
||||||
message=__(
|
message=__(f"no SQL validator is configured for {spec.engine}"),
|
||||||
"no SQL validator is configured for {}".format(spec.engine)
|
|
||||||
),
|
|
||||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||||
level=ErrorLevel.ERROR,
|
level=ErrorLevel.ERROR,
|
||||||
),
|
),
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from superset.dao.base import BaseDAO
|
from superset.dao.base import BaseDAO
|
||||||
from superset.databases.filters import DatabaseFilter
|
from superset.databases.filters import DatabaseFilter
|
||||||
|
@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO):
|
||||||
def update(
|
def update(
|
||||||
cls,
|
cls,
|
||||||
model: Database,
|
model: Database,
|
||||||
properties: Dict[str, Any],
|
properties: dict[str, Any],
|
||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> Database:
|
) -> Database:
|
||||||
"""
|
"""
|
||||||
|
@ -93,7 +93,7 @@ class DatabaseDAO(BaseDAO):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
|
def get_related_objects(cls, database_id: int) -> dict[str, Any]:
|
||||||
database: Any = cls.find_by_id(database_id)
|
database: Any = cls.find_by_id(database_id)
|
||||||
datasets = database.tables
|
datasets = database.tables
|
||||||
dataset_ids = [dataset.id for dataset in datasets]
|
dataset_ids = [dataset.id for dataset in datasets]
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Any, Set
|
from typing import Any
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
@ -30,7 +30,7 @@ from superset.views.base import BaseFilter
|
||||||
|
|
||||||
def can_access_databases(
|
def can_access_databases(
|
||||||
view_menu_name: str,
|
view_menu_name: str,
|
||||||
) -> Set[str]:
|
) -> set[str]:
|
||||||
return {
|
return {
|
||||||
security_manager.unpack_database_and_schema(vm).database
|
security_manager.unpack_database_and_schema(vm).database
|
||||||
for vm in security_manager.user_view_menu_names(view_menu_name)
|
for vm in security_manager.user_view_menu_names(view_menu_name)
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
|
@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
@pre_load
|
@pre_load
|
||||||
def build_sqlalchemy_uri(
|
def build_sqlalchemy_uri(
|
||||||
self, data: Dict[str, Any], **kwargs: Any
|
self, data: dict[str, Any], **kwargs: Any
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Build SQLAlchemy URI from separate parameters.
|
Build SQLAlchemy URI from separate parameters.
|
||||||
|
|
||||||
|
@ -325,9 +325,9 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
def rename_encrypted_extra(
|
def rename_encrypted_extra(
|
||||||
self: Schema,
|
self: Schema,
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Rename ``encrypted_extra`` to ``masked_encrypted_extra``.
|
Rename ``encrypted_extra`` to ``masked_encrypted_extra``.
|
||||||
|
|
||||||
|
@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema):
|
||||||
class ImportV1DatabaseExtraSchema(Schema):
|
class ImportV1DatabaseExtraSchema(Schema):
|
||||||
@pre_load
|
@pre_load
|
||||||
def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name
|
def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name
|
||||||
self, data: Dict[str, Any], **kwargs: Any
|
self, data: dict[str, Any], **kwargs: Any
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Fixes for ``schemas_allowed_for_csv_upload``.
|
Fixes for ``schemas_allowed_for_csv_upload``.
|
||||||
"""
|
"""
|
||||||
|
@ -744,8 +744,8 @@ class ImportV1DatabaseExtraSchema(Schema):
|
||||||
class ImportV1DatabaseSchema(Schema):
|
class ImportV1DatabaseSchema(Schema):
|
||||||
@pre_load
|
@pre_load
|
||||||
def fix_allow_csv_upload(
|
def fix_allow_csv_upload(
|
||||||
self, data: Dict[str, Any], **kwargs: Any
|
self, data: dict[str, Any], **kwargs: Any
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Fix for ``allow_csv_upload`` .
|
Fix for ``allow_csv_upload`` .
|
||||||
"""
|
"""
|
||||||
|
@ -775,7 +775,7 @@ class ImportV1DatabaseSchema(Schema):
|
||||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||||
|
|
||||||
@validates_schema
|
@validates_schema
|
||||||
def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
|
def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""If sqlalchemy_uri has a masked password, password is required"""
|
"""If sqlalchemy_uri has a masked password, password is required"""
|
||||||
uuid = data["uuid"]
|
uuid = data["uuid"]
|
||||||
existing = db.session.query(Database).filter_by(uuid=uuid).first()
|
existing = db.session.query(Database).filter_by(uuid=uuid).first()
|
||||||
|
@ -789,7 +789,7 @@ class ImportV1DatabaseSchema(Schema):
|
||||||
|
|
||||||
@validates_schema
|
@validates_schema
|
||||||
def validate_ssh_tunnel_credentials(
|
def validate_ssh_tunnel_credentials(
|
||||||
self, data: Dict[str, Any], **kwargs: Any
|
self, data: dict[str, Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""If ssh_tunnel has a masked credentials, credentials are required"""
|
"""If ssh_tunnel has a masked credentials, credentials are required"""
|
||||||
uuid = data["uuid"]
|
uuid = data["uuid"]
|
||||||
|
@ -829,7 +829,7 @@ class ImportV1DatabaseSchema(Schema):
|
||||||
# or there're times where it's masked.
|
# or there're times where it's masked.
|
||||||
# If both are masked, we need to return a list of errors
|
# If both are masked, we need to return a list of errors
|
||||||
# so the UI ask for both fields at the same time if needed
|
# so the UI ask for both fields at the same time if needed
|
||||||
exception_messages: List[str] = []
|
exception_messages: list[str] = []
|
||||||
if private_key is None or private_key == PASSWORD_MASK:
|
if private_key is None or private_key == PASSWORD_MASK:
|
||||||
# If we get here we need to ask for the private key
|
# If we get here we need to ask for the private key
|
||||||
exception_messages.append(
|
exception_messages.append(
|
||||||
|
@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore
|
def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
|
||||||
ret = {}
|
ret = {}
|
||||||
if isinstance(field, EncryptedField):
|
if isinstance(field, EncryptedField):
|
||||||
if self.openapi_version.major > 2:
|
if self.openapi_version.major > 2:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateSSHTunnelCommand(BaseCommand):
|
class CreateSSHTunnelCommand(BaseCommand):
|
||||||
def __init__(self, database_id: int, data: Dict[str, Any]):
|
def __init__(self, database_id: int, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._properties["database_id"] = database_id
|
self._properties["database_id"] = database_id
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class CreateSSHTunnelCommand(BaseCommand):
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# TODO(hughhh): check to make sure the server port is not localhost
|
# TODO(hughhh): check to make sure the server port is not localhost
|
||||||
# using the config.SSH_TUNNEL_MANAGER
|
# using the config.SSH_TUNNEL_MANAGER
|
||||||
exceptions: List[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
database_id: Optional[int] = self._properties.get("database_id")
|
database_id: Optional[int] = self._properties.get("database_id")
|
||||||
server_address: Optional[str] = self._properties.get("server_address")
|
server_address: Optional[str] = self._properties.get("server_address")
|
||||||
server_port: Optional[int] = self._properties.get("server_port")
|
server_port: Optional[int] = self._properties.get("server_port")
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateSSHTunnelCommand(BaseCommand):
|
class UpdateSSHTunnelCommand(BaseCommand):
|
||||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._model: Optional[SSHTunnel] = None
|
self._model: Optional[SSHTunnel] = None
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from superset.dao.base import BaseDAO
|
from superset.dao.base import BaseDAO
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO):
|
||||||
def update(
|
def update(
|
||||||
cls,
|
cls,
|
||||||
model: SSHTunnel,
|
model: SSHTunnel,
|
||||||
properties: Dict[str, Any],
|
properties: dict[str, Any],
|
||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> SSHTunnel:
|
) -> SSHTunnel:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
|
||||||
]
|
]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> Dict[str, Any]:
|
def data(self) -> dict[str, Any]:
|
||||||
output = {
|
output = {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"server_address": self.server_address,
|
"server_address": self.server_address,
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue