[mypy] Enforcing typing for superset.models (#9883)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-05-22 20:31:21 -07:00 committed by GitHub
parent 6d4e23663e
commit e789a35558
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 207 additions and 130 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -18,7 +18,7 @@
import logging
import re
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
import pandas as pd
@ -103,7 +103,11 @@ class AnnotationDatasource(BaseDatasource):
logger.exception(ex)
error_message = utils.error_msg_from_exception(ex)
return QueryResult(
status=status, df=df, duration=0, query="", error_message=error_message
status=status,
df=df,
duration=timedelta(0),
query="",
error_message=error_message,
)
def get_query_str(self, query_obj):

View File

@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Code related with dealing with legacy / change management"""
from typing import Any, Dict
def update_time_range(form_data):
def update_time_range(form_data: Dict[str, Any]) -> None:
"""Move since and until to time_range."""
if "since" in form_data or "until" in form_data:
form_data["time_range"] = "{} : {}".format(

View File

@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""a collection of Annotation-related models"""
from typing import Any, Dict
from flask_appbuilder import Model
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text
from sqlalchemy.orm import relationship
@ -31,7 +33,7 @@ class AnnotationLayer(Model, AuditMixinNullable):
name = Column(String(250))
descr = Column(Text)
def __repr__(self):
def __repr__(self) -> str:
return self.name
@ -52,7 +54,7 @@ class Annotation(Model, AuditMixinNullable):
__table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),)
@property
def data(self):
def data(self) -> Dict[str, Any]:
return {
"layer_id": self.layer_id,
"start_dttm": self.start_dttm,

View File

@ -152,7 +152,7 @@ class Database(
]
export_children = ["tables"]
def __repr__(self):
def __repr__(self) -> str:
return self.name
@property
@ -234,7 +234,9 @@ class Database(
return self.get_extra().get("default_schemas", [])
@classmethod
def get_password_masked_url_from_uri(cls, uri: str): # pylint: disable=invalid-name
def get_password_masked_url_from_uri( # pylint: disable=invalid-name
cls, uri: str
) -> URL:
sqlalchemy_url = make_url(uri)
return cls.get_password_masked_url(sqlalchemy_url)
@ -279,7 +281,7 @@ class Database(
effective_username = g.user.username
return effective_username
@utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
@utils.memoized(watch=["impersonate_user", "sqlalchemy_uri_decrypted", "extra"])
def get_sqla_engine(
self,
schema: Optional[str] = None,
@ -339,7 +341,7 @@ class Database(
def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words
def get_quoter(self):
def get_quoter(self) -> Callable:
return self.get_dialect().identifier_preparer.quote
def get_df( # pylint: disable=too-many-locals
@ -405,7 +407,7 @@ class Database(
indent: bool = True,
latest_partition: bool = False,
cols: Optional[List[Dict[str, Any]]] = None,
):
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
return self.db_engine_spec.select_star(
@ -436,7 +438,10 @@ class Database(
attribute_in_key="id",
)
def get_all_table_names_in_database(
self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False
self,
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
@ -547,7 +552,7 @@ class Database(
@classmethod
def get_db_engine_spec_for_backend(
cls, backend
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)
@ -565,7 +570,7 @@ class Database(
def get_extra(self) -> Dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)
def get_encrypted_extra(self):
def get_encrypted_extra(self) -> Dict[str, Any]:
encrypted_extra = {}
if self.encrypted_extra:
try:

View File

@ -36,7 +36,9 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
from sqlalchemy.orm.mapper import Mapper
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
from superset.models.helpers import AuditMixinNullable, ImportMixin
@ -59,7 +61,7 @@ config = app.config
logger = logging.getLogger(__name__)
def copy_dashboard(mapper, connection, target):
def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
# pylint: disable=unused-argument
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
@ -140,7 +142,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
"slug",
]
def __repr__(self):
def __repr__(self) -> str:
return self.dashboard_title or str(self.id)
@property
@ -202,13 +204,13 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/"
@property
def changed_by_name(self):
def changed_by_name(self) -> str:
if not self.changed_by:
return ""
return str(self.changed_by)
@property
def changed_by_url(self):
def changed_by_url(self) -> str:
if not self.changed_by:
return ""
return f"/superset/profile/{self.changed_by.username}"
@ -229,8 +231,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
"position_json": positions,
}
@property
def params(self) -> str:
@property # type: ignore
def params(self) -> str: # type: ignore
return self.json_metadata
@params.setter
@ -257,7 +259,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
Audit metadata isn't copied over.
"""
def alter_positions(dashboard, old_to_new_slc_id_dict):
def alter_positions(
dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
) -> None:
""" Updates slice_ids in the position json.
Sample position_json data:
@ -291,9 +295,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
if (
isinstance(value, dict)
and value.get("meta")
and value.get("meta").get("chartId")
and value.get("meta", {}).get("chartId")
):
old_slice_id = value.get("meta").get("chartId")
old_slice_id = value["meta"]["chartId"]
if old_slice_id in old_to_new_slc_id_dict:
value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id]
@ -470,8 +474,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
def event_after_dashboard_changed( # pylint: disable=unused-argument
mapper, connection, target
):
mapper: Mapper, connection: Connection, target: Dashboard
) -> None:
cache_dashboard_thumbnail.delay(target.id, force=True)

View File

@ -18,8 +18,8 @@
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Union
# isort and pylint disagree, isort should win
# pylint: disable=ungrouped-imports
@ -30,8 +30,10 @@ import yaml
from flask import escape, g, Markup
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, or_, UniqueConstraint
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound
from superset.utils.core import QueryStatus
@ -39,7 +41,7 @@ from superset.utils.core import QueryStatus
logger = logging.getLogger(__name__)
def json_to_dict(json_str):
def json_to_dict(json_str: str) -> Dict[Any, Any]:
if json_str:
val = re.sub(",[ \t\r\n]+}", "}", json_str)
val = re.sub(
@ -64,48 +66,56 @@ class ImportMixin:
# that are available for import and export
@classmethod
def _parent_foreign_key_mappings(cls):
def _parent_foreign_key_mappings(cls) -> Dict[str, str]:
"""Get a mapping of foreign name to the local name of foreign keys"""
parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
parent_rel = cls.__mapper__.relationships.get(cls.export_parent) # type: ignore
if parent_rel:
return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
return {}
@classmethod
def _unique_constrains(cls):
def _unique_constrains(cls) -> List[Set[str]]:
"""Get all (single column and multi column) unique constraints"""
unique = [
{c.name for c in u.columns}
for u in cls.__table_args__
for u in cls.__table_args__ # type: ignore
if isinstance(u, UniqueConstraint)
]
unique.extend({c.name} for c in cls.__table__.columns if c.unique)
unique.extend( # type: ignore
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
)
return unique
@classmethod
def export_schema(cls, recursive=True, include_parent_ref=False):
def export_schema(
cls, recursive: bool = True, include_parent_ref: bool = False
) -> Dict[str, Any]:
"""Export schema as a dictionary"""
parent_excludes = {}
parent_excludes = set()
if not include_parent_ref:
parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
parent_ref = cls.__mapper__.relationships.get( # type: ignore
cls.export_parent
)
if parent_ref:
parent_excludes = {column.name for column in parent_ref.local_columns}
def formatter(column):
def formatter(column: sa.Column) -> str:
return (
"{0} Default ({1})".format(str(column.type), column.default.arg)
if column.default
else str(column.type)
)
schema = {
schema: Dict[str, Any] = {
column.name: formatter(column)
for column in cls.__table__.columns
for column in cls.__table__.columns # type: ignore
if (column.name in cls.export_fields and column.name not in parent_excludes)
}
if recursive:
for column in cls.export_children:
child_class = cls.__mapper__.relationships[column].argument.class_
child_class = cls.__mapper__.relationships[ # type: ignore
column
].argument.class_
schema[column] = [
child_class.export_schema(
recursive=recursive, include_parent_ref=include_parent_ref
@ -114,17 +124,20 @@ class ImportMixin:
return schema
@classmethod
def import_from_dict(
cls, session, dict_rep, parent=None, recursive=True, sync=None
): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
cls,
session: Session,
dict_rep: Dict[Any, Any],
parent: Optional[Any] = None,
recursive: bool = True,
sync: Optional[List[str]] = None,
) -> Any: # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
"""Import obj from a dictionary"""
if sync is None:
sync = []
parent_refs = cls._parent_foreign_key_mappings()
export_fields = set(cls.export_fields) | set(parent_refs.keys())
new_children = {
c: dict_rep.get(c) for c in cls.export_children if c in dict_rep
}
new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep}
unique_constrains = cls._unique_constrains()
filters = [] # Using these filters to check if obj already exists
@ -178,7 +191,7 @@ class ImportMixin:
if not obj:
is_new_obj = True
# Create new DB object
obj = cls(**dict_rep)
obj = cls(**dict_rep) # type: ignore
logger.info("Importing new %s %s", obj.__tablename__, str(obj))
if cls.export_parent and parent:
setattr(obj, cls.export_parent, parent)
@ -193,7 +206,9 @@ class ImportMixin:
# Recursively create children
if recursive:
for child in cls.export_children:
child_class = cls.__mapper__.relationships[child].argument.class_
child_class = cls.__mapper__.relationships[ # type: ignore
child
].argument.class_
added = []
for c_obj in new_children.get(child, []):
added.append(
@ -221,18 +236,23 @@ class ImportMixin:
return obj
def export_to_dict(
self, recursive=True, include_parent_ref=False, include_defaults=False
):
self,
recursive: bool = True,
include_parent_ref: bool = False,
include_defaults: bool = False,
) -> Dict[Any, Any]:
"""Export obj to dictionary"""
cls = self.__class__
parent_excludes = {}
parent_excludes = set()
if recursive and not include_parent_ref:
parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
parent_ref = cls.__mapper__.relationships.get( # type: ignore
cls.export_parent
)
if parent_ref:
parent_excludes = {c.name for c in parent_ref.local_columns}
dict_rep = {
c.name: getattr(self, c.name)
for c in cls.__table__.columns
for c in cls.__table__.columns # type: ignore
if (
c.name in self.export_fields
and c.name not in parent_excludes
@ -262,18 +282,18 @@ class ImportMixin:
return dict_rep
def override(self, obj):
def override(self, obj: Any) -> None:
"""Overrides the plain fields of the dashboard."""
for field in obj.__class__.export_fields:
setattr(self, field, getattr(obj, field))
def copy(self):
def copy(self) -> Any:
"""Creates a copy of the dashboard without relationships."""
new_obj = self.__class__()
new_obj.override(self)
return new_obj
def alter_params(self, **kwargs):
def alter_params(self, **kwargs: Any) -> None:
params = self.params_dict
params.update(kwargs)
self.params = json.dumps(params)
@ -283,7 +303,7 @@ class ImportMixin:
params.pop(param_to_remove, None)
self.params = json.dumps(params)
def reset_ownership(self):
def reset_ownership(self) -> None:
""" object will belong to the user the current user """
# make sure the object doesn't have relations to a user
# it will be filled by appbuilder on save
@ -297,15 +317,15 @@ class ImportMixin:
self.owners = []
@property
def params_dict(self):
def params_dict(self) -> Dict[Any, Any]:
return json_to_dict(self.params)
@property
def template_params_dict(self):
return json_to_dict(self.template_params)
def template_params_dict(self) -> Dict[Any, Any]:
return json_to_dict(self.template_params) # type: ignore
def _user_link(user): # pylint: disable=no-self-use
def _user_link(user: User) -> Union[Markup, str]: # pylint: disable=no-self-use
if not user:
return ""
url = "/superset/profile/{}/".format(user.username)
@ -325,7 +345,7 @@ class AuditMixinNullable(AuditMixin):
)
@declared_attr
def created_by_fk(self):
def created_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
@ -334,7 +354,7 @@ class AuditMixinNullable(AuditMixin):
)
@declared_attr
def changed_by_fk(self):
def changed_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
@ -343,29 +363,29 @@ class AuditMixinNullable(AuditMixin):
nullable=True,
)
def changed_by_name(self):
def changed_by_name(self) -> str:
if self.created_by:
return escape("{}".format(self.created_by))
return ""
@renders("created_by")
def creator(self):
def creator(self) -> Union[Markup, str]:
return _user_link(self.created_by)
@property
def changed_by_(self):
def changed_by_(self) -> Union[Markup, str]:
return _user_link(self.changed_by)
@renders("changed_on")
def changed_on_(self):
def changed_on_(self) -> Markup:
return Markup(f'<span class="no-wrap">{self.changed_on}</span>')
@property
def changed_on_humanized(self):
def changed_on_humanized(self) -> str:
return humanize.naturaltime(datetime.now() - self.changed_on)
@renders("changed_on")
def modified(self):
def modified(self) -> Markup:
return Markup(f'<span class="no-wrap">{self.changed_on_humanized}</span>')
@ -375,19 +395,19 @@ class QueryResult: # pylint: disable=too-few-public-methods
def __init__( # pylint: disable=too-many-arguments
self,
df,
query,
duration,
status=QueryStatus.SUCCESS,
error_message=None,
errors=None,
):
self.df: pd.DataFrame = df
self.query: str = query
self.duration: int = duration
self.status: str = status
self.error_message: Optional[str] = error_message
self.errors: List[Dict[str, Any]] = errors or []
df: pd.DataFrame,
query: str,
duration: timedelta,
status: str = QueryStatus.SUCCESS,
error_message: Optional[str] = None,
errors: Optional[List[Dict[str, Any]]] = None,
) -> None:
self.df = df
self.query = query
self.duration = duration
self.status = status
self.error_message = error_message
self.errors = errors or []
class ExtraJSONMixin:
@ -396,16 +416,16 @@ class ExtraJSONMixin:
extra_json = sa.Column(sa.Text, default="{}")
@property
def extra(self):
def extra(self) -> Dict[str, Any]:
try:
return json.loads(self.extra_json)
except Exception: # pylint: disable=broad-except
return {}
def set_extra_json(self, extras):
def set_extra_json(self, extras: Dict[str, Any]) -> None:
self.extra_json = json.dumps(extras)
def set_extra_json_key(self, key, value):
def set_extra_json_key(self, key: str, value: Any) -> None:
extra = self.extra
extra[key] = value
self.extra_json = json.dumps(extra)

View File

@ -21,7 +21,7 @@ from typing import Optional, Type
from flask_appbuilder import Model
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, RelationshipProperty
from superset import security_manager
from superset.models.helpers import AuditMixinNullable, ImportMixin
@ -55,11 +55,11 @@ class EmailSchedule:
crontab = Column(String(50))
@declared_attr
def user_id(self):
def user_id(self) -> int:
return Column(Integer, ForeignKey("ab_user.id"))
@declared_attr
def user(self):
def user(self) -> RelationshipProperty:
return relationship(
security_manager.user_model,
backref=self.__tablename__,

View File

@ -24,7 +24,9 @@ from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from markupsafe import escape, Markup
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import make_transient, relationship
from sqlalchemy.orm.mapper import Mapper
from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
from superset.legacy import update_time_range
@ -92,7 +94,7 @@ class Slice(
"cache_timeout",
]
def __repr__(self):
def __repr__(self) -> str:
return self.slice_name or str(self.id)
@property
@ -263,7 +265,7 @@ class Slice(
@property
def changed_by_url(self) -> str:
return f"/superset/profile/{self.created_by.username}"
return f"/superset/profile/{self.created_by.username}" # type: ignore
@property
def icons(self) -> str:
@ -324,7 +326,7 @@ class Slice(
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
def set_related_perm(mapper, connection, target):
def set_related_perm(mapper: Mapper, connection: Connection, target: Slice) -> None:
# pylint: disable=unused-argument
src_class = target.cls_model
id_ = target.datasource_id
@ -336,8 +338,8 @@ def set_related_perm(mapper, connection, target):
def event_after_chart_changed( # pylint: disable=unused-argument
mapper, connection, target
):
mapper: Mapper, connection: Connection, target: Slice
) -> None:
cache_chart_thumbnail.delay(target.id, force=True)

View File

@ -17,6 +17,7 @@
"""A collection of ORM sqlalchemy models for SQL Lab"""
import re
from datetime import datetime
from typing import Any, Dict
# pylint: disable=ungrouped-imports
import simplejson as json
@ -33,6 +34,7 @@ from sqlalchemy import (
String,
Text,
)
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import backref, relationship
from superset import security_manager
@ -99,7 +101,7 @@ class Query(Model, ExtraJSONMixin):
__table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),)
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"changedOn": self.changed_on,
"changed_on": self.changed_on.isoformat(),
@ -130,7 +132,7 @@ class Query(Model, ExtraJSONMixin):
}
@property
def name(self):
def name(self) -> str:
"""Name property"""
ts = datetime.now().isoformat()
ts = ts.replace("-", "").replace(":", "").split(".")[0]
@ -139,11 +141,11 @@ class Query(Model, ExtraJSONMixin):
return f"sqllab_{tab}_{ts}"
@property
def database_name(self):
def database_name(self) -> str:
return self.database.name
@property
def username(self):
def username(self) -> str:
return self.user.username
@ -170,7 +172,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
)
@property
def pop_tab_link(self):
def pop_tab_link(self) -> Markup:
return Markup(
f"""
<a href="/superset/sqllab?savedQueryId={self.id}">
@ -180,14 +182,14 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
)
@property
def user_email(self):
def user_email(self) -> str:
return self.user.email
@property
def sqlalchemy_uri(self):
def sqlalchemy_uri(self) -> URL:
return self.database.sqlalchemy_uri
def url(self):
def url(self) -> str:
return "/superset/sqllab?savedQueryId={0}".format(self.id)
@ -226,7 +228,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin):
autorun = Column(Boolean, default=False)
template_params = Column(Text)
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"user_id": self.user_id,
@ -260,7 +262,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin):
expanded = Column(Boolean, default=False)
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
try:
description = json.loads(self.description)
except json.JSONDecodeError:

View File

@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional, Type
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
# _compiler_dispatch is defined to help with type compilation
@ -27,11 +29,11 @@ class TinyInteger(Integer):
A type for tiny ``int`` integers.
"""
def python_type(self):
def python_type(self) -> Type:
return int
@classmethod
def _compiler_dispatch(cls, _visitor, **_kw):
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "TINYINT"
@ -40,11 +42,11 @@ class Interval(TypeEngine):
A type for intervals.
"""
def python_type(self):
def python_type(self) -> Optional[Type]:
return None
@classmethod
def _compiler_dispatch(cls, _visitor, **_kw):
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "INTERVAL"
@ -53,11 +55,11 @@ class Array(TypeEngine):
A type for arrays.
"""
def python_type(self):
def python_type(self) -> Optional[Type]:
return list
@classmethod
def _compiler_dispatch(cls, _visitor, **_kw):
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ARRAY"
@ -66,11 +68,11 @@ class Map(TypeEngine):
A type for maps.
"""
def python_type(self):
def python_type(self) -> Optional[Type]:
return dict
@classmethod
def _compiler_dispatch(cls, _visitor, **_kw):
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "MAP"
@ -79,11 +81,11 @@ class Row(TypeEngine):
A type for rows.
"""
def python_type(self):
def python_type(self) -> Optional[Type]:
return None
@classmethod
def _compiler_dispatch(cls, _visitor, **_kw):
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ROW"

View File

@ -17,15 +17,23 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import enum
from typing import Optional
from typing import List, Optional, TYPE_CHECKING, Union
from flask_appbuilder import Model
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, Session, sessionmaker
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.mapper import Mapper
from superset.models.helpers import AuditMixinNullable
if TYPE_CHECKING:
from superset.models.core import FavStar # pylint: disable=unused-import
from superset.models.dashboard import Dashboard # pylint: disable=unused-import
from superset.models.slice import Slice # pylint: disable=unused-import
from superset.models.sql_lab import Query # pylint: disable=unused-import
Session = sessionmaker(autoflush=False)
@ -80,7 +88,7 @@ class TaggedObject(Model, AuditMixinNullable):
tag = relationship("Tag", backref="objects")
def get_tag(name, session, type_):
def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
try:
tag = session.query(Tag).filter_by(name=name, type=type_).one()
except NoResultFound:
@ -91,7 +99,7 @@ def get_tag(name, session, type_):
return tag
def get_object_type(class_name):
def get_object_type(class_name: str) -> ObjectTypes:
mapping = {
"slice": ObjectTypes.chart,
"dashboard": ObjectTypes.dashboard,
@ -108,11 +116,15 @@ class ObjectUpdater:
object_type: Optional[str] = None
@classmethod
def get_owners_ids(cls, target):
def get_owners_ids(
cls, target: Union["Dashboard", "FavStar", "Slice"]
) -> List[int]:
raise NotImplementedError("Subclass should implement `get_owners_ids`")
@classmethod
def _add_owners(cls, session, target):
def _add_owners(
cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
) -> None:
for owner_id in cls.get_owners_ids(target):
name = "owner:{0}".format(owner_id)
tag = get_tag(name, session, TagTypes.owner)
@ -122,7 +134,12 @@ class ObjectUpdater:
session.add(tagged_object)
@classmethod
def after_insert(cls, mapper, connection, target):
def after_insert(
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
@ -139,7 +156,12 @@ class ObjectUpdater:
session.commit()
@classmethod
def after_update(cls, mapper, connection, target):
def after_update(
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
@ -164,7 +186,12 @@ class ObjectUpdater:
session.commit()
@classmethod
def after_delete(cls, mapper, connection, target):
def after_delete(
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
@ -182,7 +209,7 @@ class ChartUpdater(ObjectUpdater):
object_type = "chart"
@classmethod
def get_owners_ids(cls, target):
def get_owners_ids(cls, target: "Slice") -> List[int]:
return [owner.id for owner in target.owners]
@ -191,7 +218,7 @@ class DashboardUpdater(ObjectUpdater):
object_type = "dashboard"
@classmethod
def get_owners_ids(cls, target):
def get_owners_ids(cls, target: "Dashboard") -> List[int]:
return [owner.id for owner in target.owners]
@ -200,13 +227,15 @@ class QueryUpdater(ObjectUpdater):
object_type = "query"
@classmethod
def get_owners_ids(cls, target):
def get_owners_ids(cls, target: "Query") -> List[int]:
return [target.user_id]
class FavStarUpdater:
@classmethod
def after_insert(cls, mapper, connection, target):
def after_insert(
cls, mapper: Mapper, connection: Connection, target: "FavStar"
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
@ -221,7 +250,9 @@ class FavStarUpdater:
session.commit()
@classmethod
def after_delete(cls, mapper, connection, target):
def after_delete(
cls, mapper: Mapper, connection: Connection, target: "FavStar"
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)

View File

@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable, Optional
from flask import request
from superset.extensions import cache_manager
@ -24,7 +26,9 @@ def view_cache_key(*_, **__) -> str:
return "view/{}/{}".format(request.path, args_hash)
def memoized_func(key=view_cache_key, attribute_in_key=None):
def memoized_func(
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
) -> Callable:
"""Use this decorator to cache functions that have predefined first arg.
enable_cache is treated as True by default,

View File

@ -143,7 +143,7 @@ class _memoized:
self.func = func
self.cache = {}
self.is_method = False
self.watch = watch
self.watch = watch or []
def __call__(self, *args, **kwargs):
key = [args, frozenset(kwargs.items())]
@ -172,7 +172,7 @@ class _memoized:
return functools.partial(self.__call__, obj)
def memoized(func=None, watch=None):
def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None):
if func:
return _memoized(func)
else: