mirror of https://github.com/apache/superset.git
[mypy] Enforcing typing for superset.models (#9883)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
6d4e23663e
commit
e789a35558
|
@ -53,7 +53,7 @@ order_by_type = false
|
|||
ignore_missing_imports = true
|
||||
no_implicit_optional = true
|
||||
|
||||
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
|
||||
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_defs = true
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import logging
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
@ -103,7 +103,11 @@ class AnnotationDatasource(BaseDatasource):
|
|||
logger.exception(ex)
|
||||
error_message = utils.error_msg_from_exception(ex)
|
||||
return QueryResult(
|
||||
status=status, df=df, duration=0, query="", error_message=error_message
|
||||
status=status,
|
||||
df=df,
|
||||
duration=timedelta(0),
|
||||
query="",
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
def get_query_str(self, query_obj):
|
||||
|
|
|
@ -15,9 +15,10 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Code related with dealing with legacy / change management"""
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def update_time_range(form_data):
|
||||
def update_time_range(form_data: Dict[str, Any]) -> None:
|
||||
"""Move since and until to time_range."""
|
||||
if "since" in form_data or "until" in form_data:
|
||||
form_data["time_range"] = "{} : {}".format(
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""a collection of Annotation-related models"""
|
||||
from typing import Any, Dict
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
|
@ -31,7 +33,7 @@ class AnnotationLayer(Model, AuditMixinNullable):
|
|||
name = Column(String(250))
|
||||
descr = Column(Text)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
|
@ -52,7 +54,7 @@ class Annotation(Model, AuditMixinNullable):
|
|||
__table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"layer_id": self.layer_id,
|
||||
"start_dttm": self.start_dttm,
|
||||
|
|
|
@ -152,7 +152,7 @@ class Database(
|
|||
]
|
||||
export_children = ["tables"]
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
|
@ -234,7 +234,9 @@ class Database(
|
|||
return self.get_extra().get("default_schemas", [])
|
||||
|
||||
@classmethod
|
||||
def get_password_masked_url_from_uri(cls, uri: str): # pylint: disable=invalid-name
|
||||
def get_password_masked_url_from_uri( # pylint: disable=invalid-name
|
||||
cls, uri: str
|
||||
) -> URL:
|
||||
sqlalchemy_url = make_url(uri)
|
||||
return cls.get_password_masked_url(sqlalchemy_url)
|
||||
|
||||
|
@ -279,7 +281,7 @@ class Database(
|
|||
effective_username = g.user.username
|
||||
return effective_username
|
||||
|
||||
@utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
|
||||
@utils.memoized(watch=["impersonate_user", "sqlalchemy_uri_decrypted", "extra"])
|
||||
def get_sqla_engine(
|
||||
self,
|
||||
schema: Optional[str] = None,
|
||||
|
@ -339,7 +341,7 @@ class Database(
|
|||
def get_reserved_words(self) -> Set[str]:
|
||||
return self.get_dialect().preparer.reserved_words
|
||||
|
||||
def get_quoter(self):
|
||||
def get_quoter(self) -> Callable:
|
||||
return self.get_dialect().identifier_preparer.quote
|
||||
|
||||
def get_df( # pylint: disable=too-many-locals
|
||||
|
@ -405,7 +407,7 @@ class Database(
|
|||
indent: bool = True,
|
||||
latest_partition: bool = False,
|
||||
cols: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
) -> str:
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
|
||||
return self.db_engine_spec.select_star(
|
||||
|
@ -436,7 +438,10 @@ class Database(
|
|||
attribute_in_key="id",
|
||||
)
|
||||
def get_all_table_names_in_database(
|
||||
self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False
|
||||
self,
|
||||
cache: bool = False,
|
||||
cache_timeout: Optional[bool] = None,
|
||||
force: bool = False,
|
||||
) -> List[utils.DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments."""
|
||||
if not self.allow_multi_schema_metadata_fetch:
|
||||
|
@ -547,7 +552,7 @@ class Database(
|
|||
|
||||
@classmethod
|
||||
def get_db_engine_spec_for_backend(
|
||||
cls, backend
|
||||
cls, backend: str
|
||||
) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
|
@ -565,7 +570,7 @@ class Database(
|
|||
def get_extra(self) -> Dict[str, Any]:
|
||||
return self.db_engine_spec.get_extra_params(self)
|
||||
|
||||
def get_encrypted_extra(self):
|
||||
def get_encrypted_extra(self) -> Dict[str, Any]:
|
||||
encrypted_extra = {}
|
||||
if self.encrypted_extra:
|
||||
try:
|
||||
|
|
|
@ -36,7 +36,9 @@ from sqlalchemy import (
|
|||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
|
||||
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
|
||||
from superset.models.helpers import AuditMixinNullable, ImportMixin
|
||||
|
@ -59,7 +61,7 @@ config = app.config
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def copy_dashboard(mapper, connection, target):
|
||||
def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
|
||||
# pylint: disable=unused-argument
|
||||
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
|
||||
if dashboard_id is None:
|
||||
|
@ -140,7 +142,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
"slug",
|
||||
]
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.dashboard_title or str(self.id)
|
||||
|
||||
@property
|
||||
|
@ -202,13 +204,13 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/"
|
||||
|
||||
@property
|
||||
def changed_by_name(self):
|
||||
def changed_by_name(self) -> str:
|
||||
if not self.changed_by:
|
||||
return ""
|
||||
return str(self.changed_by)
|
||||
|
||||
@property
|
||||
def changed_by_url(self):
|
||||
def changed_by_url(self) -> str:
|
||||
if not self.changed_by:
|
||||
return ""
|
||||
return f"/superset/profile/{self.changed_by.username}"
|
||||
|
@ -229,8 +231,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
"position_json": positions,
|
||||
}
|
||||
|
||||
@property
|
||||
def params(self) -> str:
|
||||
@property # type: ignore
|
||||
def params(self) -> str: # type: ignore
|
||||
return self.json_metadata
|
||||
|
||||
@params.setter
|
||||
|
@ -257,7 +259,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
Audit metadata isn't copied over.
|
||||
"""
|
||||
|
||||
def alter_positions(dashboard, old_to_new_slc_id_dict):
|
||||
def alter_positions(
|
||||
dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
|
||||
) -> None:
|
||||
""" Updates slice_ids in the position json.
|
||||
|
||||
Sample position_json data:
|
||||
|
@ -291,9 +295,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
if (
|
||||
isinstance(value, dict)
|
||||
and value.get("meta")
|
||||
and value.get("meta").get("chartId")
|
||||
and value.get("meta", {}).get("chartId")
|
||||
):
|
||||
old_slice_id = value.get("meta").get("chartId")
|
||||
old_slice_id = value["meta"]["chartId"]
|
||||
|
||||
if old_slice_id in old_to_new_slc_id_dict:
|
||||
value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id]
|
||||
|
@ -470,8 +474,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
|
||||
|
||||
def event_after_dashboard_changed( # pylint: disable=unused-argument
|
||||
mapper, connection, target
|
||||
):
|
||||
mapper: Mapper, connection: Connection, target: Dashboard
|
||||
) -> None:
|
||||
cache_dashboard_thumbnail.delay(target.id, force=True)
|
||||
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
# isort and pylint disagree, isort should win
|
||||
# pylint: disable=ungrouped-imports
|
||||
|
@ -30,8 +30,10 @@ import yaml
|
|||
from flask import escape, g, Markup
|
||||
from flask_appbuilder.models.decorators import renders
|
||||
from flask_appbuilder.models.mixins import AuditMixin
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from sqlalchemy import and_, or_, UniqueConstraint
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import MultipleResultsFound
|
||||
|
||||
from superset.utils.core import QueryStatus
|
||||
|
@ -39,7 +41,7 @@ from superset.utils.core import QueryStatus
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def json_to_dict(json_str):
|
||||
def json_to_dict(json_str: str) -> Dict[Any, Any]:
|
||||
if json_str:
|
||||
val = re.sub(",[ \t\r\n]+}", "}", json_str)
|
||||
val = re.sub(
|
||||
|
@ -64,48 +66,56 @@ class ImportMixin:
|
|||
# that are available for import and export
|
||||
|
||||
@classmethod
|
||||
def _parent_foreign_key_mappings(cls):
|
||||
def _parent_foreign_key_mappings(cls) -> Dict[str, str]:
|
||||
"""Get a mapping of foreign name to the local name of foreign keys"""
|
||||
parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
|
||||
parent_rel = cls.__mapper__.relationships.get(cls.export_parent) # type: ignore
|
||||
if parent_rel:
|
||||
return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _unique_constrains(cls):
|
||||
def _unique_constrains(cls) -> List[Set[str]]:
|
||||
"""Get all (single column and multi column) unique constraints"""
|
||||
unique = [
|
||||
{c.name for c in u.columns}
|
||||
for u in cls.__table_args__
|
||||
for u in cls.__table_args__ # type: ignore
|
||||
if isinstance(u, UniqueConstraint)
|
||||
]
|
||||
unique.extend({c.name} for c in cls.__table__.columns if c.unique)
|
||||
unique.extend( # type: ignore
|
||||
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
|
||||
)
|
||||
return unique
|
||||
|
||||
@classmethod
|
||||
def export_schema(cls, recursive=True, include_parent_ref=False):
|
||||
def export_schema(
|
||||
cls, recursive: bool = True, include_parent_ref: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Export schema as a dictionary"""
|
||||
parent_excludes = {}
|
||||
parent_excludes = set()
|
||||
if not include_parent_ref:
|
||||
parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
|
||||
parent_ref = cls.__mapper__.relationships.get( # type: ignore
|
||||
cls.export_parent
|
||||
)
|
||||
if parent_ref:
|
||||
parent_excludes = {column.name for column in parent_ref.local_columns}
|
||||
|
||||
def formatter(column):
|
||||
def formatter(column: sa.Column) -> str:
|
||||
return (
|
||||
"{0} Default ({1})".format(str(column.type), column.default.arg)
|
||||
if column.default
|
||||
else str(column.type)
|
||||
)
|
||||
|
||||
schema = {
|
||||
schema: Dict[str, Any] = {
|
||||
column.name: formatter(column)
|
||||
for column in cls.__table__.columns
|
||||
for column in cls.__table__.columns # type: ignore
|
||||
if (column.name in cls.export_fields and column.name not in parent_excludes)
|
||||
}
|
||||
if recursive:
|
||||
for column in cls.export_children:
|
||||
child_class = cls.__mapper__.relationships[column].argument.class_
|
||||
child_class = cls.__mapper__.relationships[ # type: ignore
|
||||
column
|
||||
].argument.class_
|
||||
schema[column] = [
|
||||
child_class.export_schema(
|
||||
recursive=recursive, include_parent_ref=include_parent_ref
|
||||
|
@ -114,17 +124,20 @@ class ImportMixin:
|
|||
return schema
|
||||
|
||||
@classmethod
|
||||
def import_from_dict(
|
||||
cls, session, dict_rep, parent=None, recursive=True, sync=None
|
||||
): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
|
||||
def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
|
||||
cls,
|
||||
session: Session,
|
||||
dict_rep: Dict[Any, Any],
|
||||
parent: Optional[Any] = None,
|
||||
recursive: bool = True,
|
||||
sync: Optional[List[str]] = None,
|
||||
) -> Any: # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
|
||||
"""Import obj from a dictionary"""
|
||||
if sync is None:
|
||||
sync = []
|
||||
parent_refs = cls._parent_foreign_key_mappings()
|
||||
export_fields = set(cls.export_fields) | set(parent_refs.keys())
|
||||
new_children = {
|
||||
c: dict_rep.get(c) for c in cls.export_children if c in dict_rep
|
||||
}
|
||||
new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep}
|
||||
unique_constrains = cls._unique_constrains()
|
||||
|
||||
filters = [] # Using these filters to check if obj already exists
|
||||
|
@ -178,7 +191,7 @@ class ImportMixin:
|
|||
if not obj:
|
||||
is_new_obj = True
|
||||
# Create new DB object
|
||||
obj = cls(**dict_rep)
|
||||
obj = cls(**dict_rep) # type: ignore
|
||||
logger.info("Importing new %s %s", obj.__tablename__, str(obj))
|
||||
if cls.export_parent and parent:
|
||||
setattr(obj, cls.export_parent, parent)
|
||||
|
@ -193,7 +206,9 @@ class ImportMixin:
|
|||
# Recursively create children
|
||||
if recursive:
|
||||
for child in cls.export_children:
|
||||
child_class = cls.__mapper__.relationships[child].argument.class_
|
||||
child_class = cls.__mapper__.relationships[ # type: ignore
|
||||
child
|
||||
].argument.class_
|
||||
added = []
|
||||
for c_obj in new_children.get(child, []):
|
||||
added.append(
|
||||
|
@ -221,18 +236,23 @@ class ImportMixin:
|
|||
return obj
|
||||
|
||||
def export_to_dict(
|
||||
self, recursive=True, include_parent_ref=False, include_defaults=False
|
||||
):
|
||||
self,
|
||||
recursive: bool = True,
|
||||
include_parent_ref: bool = False,
|
||||
include_defaults: bool = False,
|
||||
) -> Dict[Any, Any]:
|
||||
"""Export obj to dictionary"""
|
||||
cls = self.__class__
|
||||
parent_excludes = {}
|
||||
parent_excludes = set()
|
||||
if recursive and not include_parent_ref:
|
||||
parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
|
||||
parent_ref = cls.__mapper__.relationships.get( # type: ignore
|
||||
cls.export_parent
|
||||
)
|
||||
if parent_ref:
|
||||
parent_excludes = {c.name for c in parent_ref.local_columns}
|
||||
dict_rep = {
|
||||
c.name: getattr(self, c.name)
|
||||
for c in cls.__table__.columns
|
||||
for c in cls.__table__.columns # type: ignore
|
||||
if (
|
||||
c.name in self.export_fields
|
||||
and c.name not in parent_excludes
|
||||
|
@ -262,18 +282,18 @@ class ImportMixin:
|
|||
|
||||
return dict_rep
|
||||
|
||||
def override(self, obj):
|
||||
def override(self, obj: Any) -> None:
|
||||
"""Overrides the plain fields of the dashboard."""
|
||||
for field in obj.__class__.export_fields:
|
||||
setattr(self, field, getattr(obj, field))
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> Any:
|
||||
"""Creates a copy of the dashboard without relationships."""
|
||||
new_obj = self.__class__()
|
||||
new_obj.override(self)
|
||||
return new_obj
|
||||
|
||||
def alter_params(self, **kwargs):
|
||||
def alter_params(self, **kwargs: Any) -> None:
|
||||
params = self.params_dict
|
||||
params.update(kwargs)
|
||||
self.params = json.dumps(params)
|
||||
|
@ -283,7 +303,7 @@ class ImportMixin:
|
|||
params.pop(param_to_remove, None)
|
||||
self.params = json.dumps(params)
|
||||
|
||||
def reset_ownership(self):
|
||||
def reset_ownership(self) -> None:
|
||||
""" object will belong to the user the current user """
|
||||
# make sure the object doesn't have relations to a user
|
||||
# it will be filled by appbuilder on save
|
||||
|
@ -297,15 +317,15 @@ class ImportMixin:
|
|||
self.owners = []
|
||||
|
||||
@property
|
||||
def params_dict(self):
|
||||
def params_dict(self) -> Dict[Any, Any]:
|
||||
return json_to_dict(self.params)
|
||||
|
||||
@property
|
||||
def template_params_dict(self):
|
||||
return json_to_dict(self.template_params)
|
||||
def template_params_dict(self) -> Dict[Any, Any]:
|
||||
return json_to_dict(self.template_params) # type: ignore
|
||||
|
||||
|
||||
def _user_link(user): # pylint: disable=no-self-use
|
||||
def _user_link(user: User) -> Union[Markup, str]: # pylint: disable=no-self-use
|
||||
if not user:
|
||||
return ""
|
||||
url = "/superset/profile/{}/".format(user.username)
|
||||
|
@ -325,7 +345,7 @@ class AuditMixinNullable(AuditMixin):
|
|||
)
|
||||
|
||||
@declared_attr
|
||||
def created_by_fk(self):
|
||||
def created_by_fk(self) -> sa.Column:
|
||||
return sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey("ab_user.id"),
|
||||
|
@ -334,7 +354,7 @@ class AuditMixinNullable(AuditMixin):
|
|||
)
|
||||
|
||||
@declared_attr
|
||||
def changed_by_fk(self):
|
||||
def changed_by_fk(self) -> sa.Column:
|
||||
return sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey("ab_user.id"),
|
||||
|
@ -343,29 +363,29 @@ class AuditMixinNullable(AuditMixin):
|
|||
nullable=True,
|
||||
)
|
||||
|
||||
def changed_by_name(self):
|
||||
def changed_by_name(self) -> str:
|
||||
if self.created_by:
|
||||
return escape("{}".format(self.created_by))
|
||||
return ""
|
||||
|
||||
@renders("created_by")
|
||||
def creator(self):
|
||||
def creator(self) -> Union[Markup, str]:
|
||||
return _user_link(self.created_by)
|
||||
|
||||
@property
|
||||
def changed_by_(self):
|
||||
def changed_by_(self) -> Union[Markup, str]:
|
||||
return _user_link(self.changed_by)
|
||||
|
||||
@renders("changed_on")
|
||||
def changed_on_(self):
|
||||
def changed_on_(self) -> Markup:
|
||||
return Markup(f'<span class="no-wrap">{self.changed_on}</span>')
|
||||
|
||||
@property
|
||||
def changed_on_humanized(self):
|
||||
def changed_on_humanized(self) -> str:
|
||||
return humanize.naturaltime(datetime.now() - self.changed_on)
|
||||
|
||||
@renders("changed_on")
|
||||
def modified(self):
|
||||
def modified(self) -> Markup:
|
||||
return Markup(f'<span class="no-wrap">{self.changed_on_humanized}</span>')
|
||||
|
||||
|
||||
|
@ -375,19 +395,19 @@ class QueryResult: # pylint: disable=too-few-public-methods
|
|||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
df,
|
||||
query,
|
||||
duration,
|
||||
status=QueryStatus.SUCCESS,
|
||||
error_message=None,
|
||||
errors=None,
|
||||
):
|
||||
self.df: pd.DataFrame = df
|
||||
self.query: str = query
|
||||
self.duration: int = duration
|
||||
self.status: str = status
|
||||
self.error_message: Optional[str] = error_message
|
||||
self.errors: List[Dict[str, Any]] = errors or []
|
||||
df: pd.DataFrame,
|
||||
query: str,
|
||||
duration: timedelta,
|
||||
status: str = QueryStatus.SUCCESS,
|
||||
error_message: Optional[str] = None,
|
||||
errors: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
self.df = df
|
||||
self.query = query
|
||||
self.duration = duration
|
||||
self.status = status
|
||||
self.error_message = error_message
|
||||
self.errors = errors or []
|
||||
|
||||
|
||||
class ExtraJSONMixin:
|
||||
|
@ -396,16 +416,16 @@ class ExtraJSONMixin:
|
|||
extra_json = sa.Column(sa.Text, default="{}")
|
||||
|
||||
@property
|
||||
def extra(self):
|
||||
def extra(self) -> Dict[str, Any]:
|
||||
try:
|
||||
return json.loads(self.extra_json)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return {}
|
||||
|
||||
def set_extra_json(self, extras):
|
||||
def set_extra_json(self, extras: Dict[str, Any]) -> None:
|
||||
self.extra_json = json.dumps(extras)
|
||||
|
||||
def set_extra_json_key(self, key, value):
|
||||
def set_extra_json_key(self, key: str, value: Any) -> None:
|
||||
extra = self.extra
|
||||
extra[key] = value
|
||||
self.extra_json = json.dumps(extra)
|
||||
|
|
|
@ -21,7 +21,7 @@ from typing import Optional, Type
|
|||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, RelationshipProperty
|
||||
|
||||
from superset import security_manager
|
||||
from superset.models.helpers import AuditMixinNullable, ImportMixin
|
||||
|
@ -55,11 +55,11 @@ class EmailSchedule:
|
|||
crontab = Column(String(50))
|
||||
|
||||
@declared_attr
|
||||
def user_id(self):
|
||||
def user_id(self) -> int:
|
||||
return Column(Integer, ForeignKey("ab_user.id"))
|
||||
|
||||
@declared_attr
|
||||
def user(self):
|
||||
def user(self) -> RelationshipProperty:
|
||||
return relationship(
|
||||
security_manager.user_model,
|
||||
backref=self.__tablename__,
|
||||
|
|
|
@ -24,7 +24,9 @@ from flask_appbuilder import Model
|
|||
from flask_appbuilder.models.decorators import renders
|
||||
from markupsafe import escape, Markup
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.orm import make_transient, relationship
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
|
||||
from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
|
||||
from superset.legacy import update_time_range
|
||||
|
@ -92,7 +94,7 @@ class Slice(
|
|||
"cache_timeout",
|
||||
]
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.slice_name or str(self.id)
|
||||
|
||||
@property
|
||||
|
@ -263,7 +265,7 @@ class Slice(
|
|||
|
||||
@property
|
||||
def changed_by_url(self) -> str:
|
||||
return f"/superset/profile/{self.created_by.username}"
|
||||
return f"/superset/profile/{self.created_by.username}" # type: ignore
|
||||
|
||||
@property
|
||||
def icons(self) -> str:
|
||||
|
@ -324,7 +326,7 @@ class Slice(
|
|||
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
|
||||
|
||||
|
||||
def set_related_perm(mapper, connection, target):
|
||||
def set_related_perm(mapper: Mapper, connection: Connection, target: Slice) -> None:
|
||||
# pylint: disable=unused-argument
|
||||
src_class = target.cls_model
|
||||
id_ = target.datasource_id
|
||||
|
@ -336,8 +338,8 @@ def set_related_perm(mapper, connection, target):
|
|||
|
||||
|
||||
def event_after_chart_changed( # pylint: disable=unused-argument
|
||||
mapper, connection, target
|
||||
):
|
||||
mapper: Mapper, connection: Connection, target: Slice
|
||||
) -> None:
|
||||
cache_chart_thumbnail.delay(target.id, force=True)
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
"""A collection of ORM sqlalchemy models for SQL Lab"""
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
# pylint: disable=ungrouped-imports
|
||||
import simplejson as json
|
||||
|
@ -33,6 +34,7 @@ from sqlalchemy import (
|
|||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
|
||||
from superset import security_manager
|
||||
|
@ -99,7 +101,7 @@ class Query(Model, ExtraJSONMixin):
|
|||
|
||||
__table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"changedOn": self.changed_on,
|
||||
"changed_on": self.changed_on.isoformat(),
|
||||
|
@ -130,7 +132,7 @@ class Query(Model, ExtraJSONMixin):
|
|||
}
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""Name property"""
|
||||
ts = datetime.now().isoformat()
|
||||
ts = ts.replace("-", "").replace(":", "").split(".")[0]
|
||||
|
@ -139,11 +141,11 @@ class Query(Model, ExtraJSONMixin):
|
|||
return f"sqllab_{tab}_{ts}"
|
||||
|
||||
@property
|
||||
def database_name(self):
|
||||
def database_name(self) -> str:
|
||||
return self.database.name
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
def username(self) -> str:
|
||||
return self.user.username
|
||||
|
||||
|
||||
|
@ -170,7 +172,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
|
|||
)
|
||||
|
||||
@property
|
||||
def pop_tab_link(self):
|
||||
def pop_tab_link(self) -> Markup:
|
||||
return Markup(
|
||||
f"""
|
||||
<a href="/superset/sqllab?savedQueryId={self.id}">
|
||||
|
@ -180,14 +182,14 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
|
|||
)
|
||||
|
||||
@property
|
||||
def user_email(self):
|
||||
def user_email(self) -> str:
|
||||
return self.user.email
|
||||
|
||||
@property
|
||||
def sqlalchemy_uri(self):
|
||||
def sqlalchemy_uri(self) -> URL:
|
||||
return self.database.sqlalchemy_uri
|
||||
|
||||
def url(self):
|
||||
def url(self) -> str:
|
||||
return "/superset/sqllab?savedQueryId={0}".format(self.id)
|
||||
|
||||
|
||||
|
@ -226,7 +228,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin):
|
|||
autorun = Column(Boolean, default=False)
|
||||
template_params = Column(Text)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"user_id": self.user_id,
|
||||
|
@ -260,7 +262,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin):
|
|||
|
||||
expanded = Column(Boolean, default=False)
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
try:
|
||||
description = json.loads(self.description)
|
||||
except json.JSONDecodeError:
|
||||
|
|
|
@ -14,10 +14,12 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.sql.sqltypes import Integer
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from sqlalchemy.sql.visitors import Visitable
|
||||
|
||||
# _compiler_dispatch is defined to help with type compilation
|
||||
|
||||
|
@ -27,11 +29,11 @@ class TinyInteger(Integer):
|
|||
A type for tiny ``int`` integers.
|
||||
"""
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> Type:
|
||||
return int
|
||||
|
||||
@classmethod
|
||||
def _compiler_dispatch(cls, _visitor, **_kw):
|
||||
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
||||
return "TINYINT"
|
||||
|
||||
|
||||
|
@ -40,11 +42,11 @@ class Interval(TypeEngine):
|
|||
A type for intervals.
|
||||
"""
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> Optional[Type]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _compiler_dispatch(cls, _visitor, **_kw):
|
||||
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
||||
return "INTERVAL"
|
||||
|
||||
|
||||
|
@ -53,11 +55,11 @@ class Array(TypeEngine):
|
|||
A type for arrays.
|
||||
"""
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> Optional[Type]:
|
||||
return list
|
||||
|
||||
@classmethod
|
||||
def _compiler_dispatch(cls, _visitor, **_kw):
|
||||
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
||||
return "ARRAY"
|
||||
|
||||
|
||||
|
@ -66,11 +68,11 @@ class Map(TypeEngine):
|
|||
A type for maps.
|
||||
"""
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> Optional[Type]:
|
||||
return dict
|
||||
|
||||
@classmethod
|
||||
def _compiler_dispatch(cls, _visitor, **_kw):
|
||||
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
||||
return "MAP"
|
||||
|
||||
|
||||
|
@ -79,11 +81,11 @@ class Row(TypeEngine):
|
|||
A type for rows.
|
||||
"""
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> Optional[Type]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _compiler_dispatch(cls, _visitor, **_kw):
|
||||
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
|
||||
return "ROW"
|
||||
|
||||
|
||||
|
|
|
@ -17,15 +17,23 @@
|
|||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship, sessionmaker
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.orm import relationship, Session, sessionmaker
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import FavStar # pylint: disable=unused-import
|
||||
from superset.models.dashboard import Dashboard # pylint: disable=unused-import
|
||||
from superset.models.slice import Slice # pylint: disable=unused-import
|
||||
from superset.models.sql_lab import Query # pylint: disable=unused-import
|
||||
|
||||
Session = sessionmaker(autoflush=False)
|
||||
|
||||
|
||||
|
@ -80,7 +88,7 @@ class TaggedObject(Model, AuditMixinNullable):
|
|||
tag = relationship("Tag", backref="objects")
|
||||
|
||||
|
||||
def get_tag(name, session, type_):
|
||||
def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
|
||||
try:
|
||||
tag = session.query(Tag).filter_by(name=name, type=type_).one()
|
||||
except NoResultFound:
|
||||
|
@ -91,7 +99,7 @@ def get_tag(name, session, type_):
|
|||
return tag
|
||||
|
||||
|
||||
def get_object_type(class_name):
|
||||
def get_object_type(class_name: str) -> ObjectTypes:
|
||||
mapping = {
|
||||
"slice": ObjectTypes.chart,
|
||||
"dashboard": ObjectTypes.dashboard,
|
||||
|
@ -108,11 +116,15 @@ class ObjectUpdater:
|
|||
object_type: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_owners_ids(cls, target):
|
||||
def get_owners_ids(
|
||||
cls, target: Union["Dashboard", "FavStar", "Slice"]
|
||||
) -> List[int]:
|
||||
raise NotImplementedError("Subclass should implement `get_owners_ids`")
|
||||
|
||||
@classmethod
|
||||
def _add_owners(cls, session, target):
|
||||
def _add_owners(
|
||||
cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
|
||||
) -> None:
|
||||
for owner_id in cls.get_owners_ids(target):
|
||||
name = "owner:{0}".format(owner_id)
|
||||
tag = get_tag(name, session, TagTypes.owner)
|
||||
|
@ -122,7 +134,12 @@ class ObjectUpdater:
|
|||
session.add(tagged_object)
|
||||
|
||||
@classmethod
|
||||
def after_insert(cls, mapper, connection, target):
|
||||
def after_insert(
|
||||
cls,
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
target: Union["Dashboard", "FavStar", "Slice"],
|
||||
) -> None:
|
||||
# pylint: disable=unused-argument
|
||||
session = Session(bind=connection)
|
||||
|
||||
|
@ -139,7 +156,12 @@ class ObjectUpdater:
|
|||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def after_update(cls, mapper, connection, target):
|
||||
def after_update(
|
||||
cls,
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
target: Union["Dashboard", "FavStar", "Slice"],
|
||||
) -> None:
|
||||
# pylint: disable=unused-argument
|
||||
session = Session(bind=connection)
|
||||
|
||||
|
@ -164,7 +186,12 @@ class ObjectUpdater:
|
|||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def after_delete(cls, mapper, connection, target):
|
||||
def after_delete(
|
||||
cls,
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
target: Union["Dashboard", "FavStar", "Slice"],
|
||||
) -> None:
|
||||
# pylint: disable=unused-argument
|
||||
session = Session(bind=connection)
|
||||
|
||||
|
@ -182,7 +209,7 @@ class ChartUpdater(ObjectUpdater):
|
|||
object_type = "chart"
|
||||
|
||||
@classmethod
|
||||
def get_owners_ids(cls, target):
|
||||
def get_owners_ids(cls, target: "Slice") -> List[int]:
|
||||
return [owner.id for owner in target.owners]
|
||||
|
||||
|
||||
|
@ -191,7 +218,7 @@ class DashboardUpdater(ObjectUpdater):
|
|||
object_type = "dashboard"
|
||||
|
||||
@classmethod
|
||||
def get_owners_ids(cls, target):
|
||||
def get_owners_ids(cls, target: "Dashboard") -> List[int]:
|
||||
return [owner.id for owner in target.owners]
|
||||
|
||||
|
||||
|
@ -200,13 +227,15 @@ class QueryUpdater(ObjectUpdater):
|
|||
object_type = "query"
|
||||
|
||||
@classmethod
|
||||
def get_owners_ids(cls, target):
|
||||
def get_owners_ids(cls, target: "Query") -> List[int]:
|
||||
return [target.user_id]
|
||||
|
||||
|
||||
class FavStarUpdater:
|
||||
@classmethod
|
||||
def after_insert(cls, mapper, connection, target):
|
||||
def after_insert(
|
||||
cls, mapper: Mapper, connection: Connection, target: "FavStar"
|
||||
) -> None:
|
||||
# pylint: disable=unused-argument
|
||||
session = Session(bind=connection)
|
||||
name = "favorited_by:{0}".format(target.user_id)
|
||||
|
@ -221,7 +250,9 @@ class FavStarUpdater:
|
|||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def after_delete(cls, mapper, connection, target):
|
||||
def after_delete(
|
||||
cls, mapper: Mapper, connection: Connection, target: "FavStar"
|
||||
) -> None:
|
||||
# pylint: disable=unused-argument
|
||||
session = Session(bind=connection)
|
||||
name = "favorited_by:{0}".format(target.user_id)
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Callable, Optional
|
||||
|
||||
from flask import request
|
||||
|
||||
from superset.extensions import cache_manager
|
||||
|
@ -24,7 +26,9 @@ def view_cache_key(*_, **__) -> str:
|
|||
return "view/{}/{}".format(request.path, args_hash)
|
||||
|
||||
|
||||
def memoized_func(key=view_cache_key, attribute_in_key=None):
|
||||
def memoized_func(
|
||||
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
|
||||
) -> Callable:
|
||||
"""Use this decorator to cache functions that have predefined first arg.
|
||||
|
||||
enable_cache is treated as True by default,
|
||||
|
|
|
@ -143,7 +143,7 @@ class _memoized:
|
|||
self.func = func
|
||||
self.cache = {}
|
||||
self.is_method = False
|
||||
self.watch = watch
|
||||
self.watch = watch or []
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
key = [args, frozenset(kwargs.items())]
|
||||
|
@ -172,7 +172,7 @@ class _memoized:
|
|||
return functools.partial(self.__call__, obj)
|
||||
|
||||
|
||||
def memoized(func=None, watch=None):
|
||||
def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None):
|
||||
if func:
|
||||
return _memoized(func)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue