[mypy] Enforcing typing for a number of modules (#9586)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
Author: John Bodley, 2020-04-24 10:07:35 -07:00 (committed by GitHub)
commit 1c656feb95 (parent 7d5f4494d0)
12 changed files with 121 additions and 79 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*]
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
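
For context, this hunk extends the strict per-module mypy section of setup.cfg to also cover superset.queries.*, superset.security.*, superset.sql_validators.* and superset.tasks.*, so those packages now get check_untyped_defs, disallow_untyped_calls and disallow_untyped_defs. A minimal sketch (not Superset code) of what the last flag enforces:

def untyped(msg):  # flagged under disallow_untyped_defs: missing annotations
    return msg


def typed(msg: str) -> str:  # accepted
    return msg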

View File

@ -823,6 +823,7 @@ class DruidDatasource(Model, BaseDatasource):
if origin:
dttm = utils.parse_human_datetime(origin)
assert dttm
granularity["origin"] = dttm.isoformat()
if period_name in iso_8601_dict:
@ -978,6 +979,7 @@ class DruidDatasource(Model, BaseDatasource):
# TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid
if self.fetch_values_from:
from_dttm = utils.parse_human_datetime(self.fetch_values_from)
assert from_dttm
else:
from_dttm = datetime(1970, 1, 1)
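
The two assert statements added above are type-narrowing guards: parse_human_datetime is annotated later in this commit as returning Optional[datetime], so calling .isoformat() on its result would otherwise make mypy report that None has no attribute "isoformat". A small self-contained sketch of the pattern, with a stand-in parser rather than the real helper:

from datetime import datetime
from typing import Optional


def parse_human_datetime(value: str) -> Optional[datetime]:  # stand-in signature
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        return None


def origin_isoformat(origin: str) -> str:
    dttm = parse_human_datetime(origin)
    assert dttm  # narrows Optional[datetime] to datetime for mypy
    return dttm.isoformat()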

View File

@ -41,7 +41,7 @@ class SupersetTimeoutException(SupersetException):
class SupersetSecurityException(SupersetException):
status = 401
def __init__(self, msg, link=None):
def __init__(self, msg: str, link: Optional[str] = None) -> None:
super(SupersetSecurityException, self).__init__(msg)
self.link = link
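
This __init__ follows the shape used throughout the commit: constructors are annotated with -> None, and because the config above sets no_implicit_optional = true, a None default must be spelled out as Optional[...]. A minimal sketch with a stand-in class, not the Superset exception:

from typing import Optional


class ReportError(Exception):  # illustrative stand-in
    status = 401

    def __init__(self, msg: str, link: Optional[str] = None) -> None:
        super().__init__(msg)
        self.link = link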

View File

@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Models for scheduled execution of jobs"""
import enum
from typing import Optional, Type
from flask_appbuilder import Model
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
@ -86,9 +86,9 @@ class SliceEmailSchedule(Model, AuditMixinNullable, ImportMixin, EmailSchedule):
email_format = Column(Enum(SliceEmailReportFormat))
def get_scheduler_model(report_type):
if report_type == ScheduleType.dashboard.value:
def get_scheduler_model(report_type: ScheduleType) -> Optional[Type[EmailSchedule]]:
if report_type == ScheduleType.dashboard:
return DashboardEmailSchedule
elif report_type == ScheduleType.slice.value:
elif report_type == ScheduleType.slice:
return SliceEmailSchedule
return None
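
The signature above is typical of the commit: the lookup now returns Optional[Type[EmailSchedule]], and the comparisons use the ScheduleType members themselves rather than their .value strings, which keeps both sides of == the same type. A self-contained sketch with stand-in names:

import enum
from typing import Optional, Type


class ReportType(enum.Enum):  # stand-in for ScheduleType
    dashboard = "dashboard"
    slice = "slice"


class BaseSchedule:  # stand-in for EmailSchedule
    pass


class DashboardSchedule(BaseSchedule):
    pass


class SliceSchedule(BaseSchedule):
    pass


def get_model(report_type: ReportType) -> Optional[Type[BaseSchedule]]:
    if report_type == ReportType.dashboard:
        return DashboardSchedule
    if report_type == ReportType.slice:
        return SliceSchedule
    return None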

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy.engine.url import URL
class DBSecurityException(Exception):
@ -22,7 +23,7 @@ class DBSecurityException(Exception):
status = 400
def check_sqlalchemy_uri(uri):
def check_sqlalchemy_uri(uri: URL) -> None:
if uri.startswith("sqlite"):
# sqlite creates a local DB, which allows mapping server's filesystem
raise DBSecurityException(

View File

@ -38,6 +38,7 @@ from flask_appbuilder.widgets import ListWidget
from sqlalchemy import or_
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query
from superset import sql_parse
from superset.connectors.connector_registry import ConnectorRegistry
@ -70,7 +71,7 @@ class SupersetRoleListWidget(ListWidget):
template = "superset/fab_overrides/list_role.html"
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
kwargs["appbuilder"] = current_app.appbuilder
super().__init__(**kwargs)
@ -580,7 +581,7 @@ class SupersetSecurityManager(SecurityManager):
if pv.permission and pv.view_menu:
all_pvs.add((pv.permission.name, pv.view_menu.name))
def merge_pv(view_menu, perm):
def merge_pv(view_menu: str, perm: str) -> None:
"""Create permission view menu only if it doesn't exist"""
if view_menu and perm and (view_menu, perm) not in all_pvs:
self.add_permission_view_menu(view_menu, perm)
@ -899,7 +900,7 @@ class SupersetSecurityManager(SecurityManager):
self.assert_datasource_permission(viz.datasource)
def get_rls_filters(self, table: "BaseDatasource"):
def get_rls_filters(self, table: "BaseDatasource") -> List[Query]:
"""
Retrieves the appropriate row level security filters for the current user and the passed table.
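
Two details in this file are worth noting: the new sqlalchemy Query import gives get_rls_filters a concrete List[Query] return type, while the parameter keeps the quoted "BaseDatasource" annotation, which suggests a type-checking-only import. A sketch of that forward-reference pattern; the module paths here are assumptions, not taken from the diff:

from typing import List, TYPE_CHECKING

if TYPE_CHECKING:  # imports needed only by the type checker
    from sqlalchemy.orm.query import Query
    from superset.connectors.base.models import BaseDatasource  # assumed path


def get_rls_filters(table: "BaseDatasource") -> List["Query"]:
    # ... look up the row level security filters for the current user ...
    return []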

View File

@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional
from flask import g
from superset import app, security_manager
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource
@ -44,7 +45,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate_statement(
cls, statement, database, cursor, user_name
cls, statement: str, database: Database, cursor: Any, user_name: str
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
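
In the validator signature above, the domain objects (statement, database, user_name) get precise types while the DB-API cursor stays Any, since there is no useful type for it here. The shape of the change, sketched with stand-in classes rather than the Superset ones:

from typing import Any, Optional


class Database:  # stand-in for superset.models.core.Database
    pass


class Annotation:  # stand-in for SQLValidationAnnotation
    def __init__(self, message: str) -> None:
        self.message = message


class StatementValidator:  # illustrative, not the Superset validator
    @classmethod
    def validate_statement(
        cls, statement: str, database: Database, cursor: Any, user_name: str
    ) -> Optional[Annotation]:
        if not statement.strip():
            return None
        return Annotation(message="statement parsed")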

View File

@ -18,7 +18,7 @@
import json
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union
from urllib import request
from urllib.error import URLError
@ -38,7 +38,9 @@ logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)
def get_form_data(chart_id, dashboard=None):
def get_form_data(
chart_id: int, dashboard: Optional[Dashboard] = None
) -> Dict[str, Any]:
"""
Build `form_data` for chart GET request from dashboard's `default_filters`.
@ -46,7 +48,7 @@ def get_form_data(chart_id, dashboard=None):
filters in the GET request for charts.
"""
form_data = {"slice_id": chart_id}
form_data: Dict[str, Any] = {"slice_id": chart_id}
if dashboard is None or not dashboard.json_metadata:
return form_data
@ -72,7 +74,7 @@ def get_form_data(chart_id, dashboard=None):
return form_data
def get_url(chart, extra_filters: Optional[Dict[str, Any]] = None):
def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
"""Return external URL for warming up a given chart/table cache."""
with app.test_request_context():
baseurl = (
@ -106,10 +108,10 @@ class Strategy:
"""
def __init__(self):
def __init__(self) -> None:
pass
def get_urls(self):
def get_urls(self) -> List[str]:
raise NotImplementedError("Subclasses must implement get_urls!")
@ -131,7 +133,7 @@ class DummyStrategy(Strategy):
name = "dummy"
def get_urls(self):
def get_urls(self) -> List[str]:
session = db.create_scoped_session()
charts = session.query(Slice).all()
@ -158,12 +160,12 @@ class TopNDashboardsStrategy(Strategy):
name = "top_n_dashboards"
def __init__(self, top_n=5, since="7 days ago"):
def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None:
super(TopNDashboardsStrategy, self).__init__()
self.top_n = top_n
self.since = parse_human_datetime(since)
def get_urls(self):
def get_urls(self) -> List[str]:
urls = []
session = db.create_scoped_session()
@ -203,11 +205,11 @@ class DashboardTagsStrategy(Strategy):
name = "dashboard_tags"
def __init__(self, tags=None):
def __init__(self, tags: Optional[List[str]] = None) -> None:
super(DashboardTagsStrategy, self).__init__()
self.tags = tags or []
def get_urls(self):
def get_urls(self) -> List[str]:
urls = []
session = db.create_scoped_session()
@ -254,7 +256,9 @@ strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
@celery_app.task(name="cache-warmup")
def cache_warmup(strategy_name, *args, **kwargs):
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
) -> Union[Dict[str, List[str]], str]:
"""
Warm up cache.
@ -264,7 +268,7 @@ def cache_warmup(strategy_name, *args, **kwargs):
logger.info("Loading strategy")
class_ = None
for class_ in strategies:
if class_.name == strategy_name:
if class_.name == strategy_name: # type: ignore
break
else:
message = f"No strategy {strategy_name} found!"
@ -280,7 +284,7 @@ def cache_warmup(strategy_name, *args, **kwargs):
logger.exception(message)
return message
results = {"success": [], "errors": []}
results: Dict[str, List[str]] = {"success": [], "errors": []}
for url in strategy.get_urls():
try:
logger.info(f"Fetching {url}")
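
Two recurring annotation patterns show up in this file: an initially empty container such as results needs an explicit Dict[str, List[str]] annotation for mypy to know its element types, and the raising base-class method gets the same List[str] return type as its overrides. A compact sketch with illustrative names, not the Superset classes:

from typing import Dict, List


class Strategy:
    name: str = "base"

    def get_urls(self) -> List[str]:
        raise NotImplementedError("Subclasses must implement get_urls!")


class DummyStrategy(Strategy):
    name = "dummy"

    def get_urls(self) -> List[str]:
        return ["http://localhost/chart/1", "http://localhost/chart/2"]


def warm_up(strategy: Strategy) -> Dict[str, List[str]]:
    results: Dict[str, List[str]] = {"success": [], "errors": []}
    for url in strategy.get_urls():
        results["success"].append(url)  # a real warm-up would fetch the URL here
    return results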

View File

@ -25,7 +25,7 @@ from superset import create_app
from superset.extensions import celery_app
# Init the Flask app / configure everything
create_app()
create_app() # type: ignore
# Need to import late, as the celery_app will have been setup by "create_app()"
# pylint: disable=wrong-import-position, unused-import

View File

@ -23,10 +23,12 @@ import urllib.request
from collections import namedtuple
from datetime import datetime, timedelta
from email.utils import make_msgid, parseaddr
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from urllib.error import URLError # pylint: disable=ungrouped-imports
import croniter
import simplejson as json
from celery.app.task import Task
from dateutil.tz import tzlocal
from flask import render_template, Response, session, url_for
from flask_babel import gettext as __
@ -34,16 +36,20 @@ from flask_login import login_user
from retry.api import retry_call
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from werkzeug.datastructures import TypeConversionDict
from werkzeug.http import parse_cookie
# Superset framework imports
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.schedules import (
DashboardEmailSchedule,
EmailDeliveryType,
EmailSchedule,
get_scheduler_model,
ScheduleType,
SliceEmailReportFormat,
SliceEmailSchedule,
)
from superset.utils.core import get_email_address_list, send_email_smtp
@ -59,7 +65,9 @@ PAGE_RENDER_WAIT = 30
EmailContent = namedtuple("EmailContent", ["body", "data", "images"])
def _get_recipients(schedule):
def _get_recipients(
schedule: Union[DashboardEmailSchedule, SliceEmailSchedule]
) -> Iterator[Tuple[str, str]]:
bcc = config["EMAIL_REPORT_BCC_ADDRESS"]
if schedule.deliver_as_group:
@ -70,7 +78,11 @@ def _get_recipients(schedule):
yield (to, bcc)
def _deliver_email(schedule, subject, email):
def _deliver_email(
schedule: Union[DashboardEmailSchedule, SliceEmailSchedule],
subject: str,
email: EmailContent,
) -> None:
for (to, bcc) in _get_recipients(schedule):
send_email_smtp(
to,
@ -85,7 +97,11 @@ def _deliver_email(schedule, subject, email):
)
def _generate_mail_content(schedule, screenshot, name, url):
def _generate_mail_content(
schedule: EmailSchedule, screenshot: bytes, name: str, url: str
) -> EmailContent:
data: Optional[Dict[str, Any]]
if schedule.delivery_type == EmailDeliveryType.attachment:
images = None
data = {"screenshot.png": screenshot}
@ -115,7 +131,7 @@ def _generate_mail_content(schedule, screenshot, name, url):
return EmailContent(body, data, images)
def _get_auth_cookies():
def _get_auth_cookies() -> List[TypeConversionDict]:
# Login with the user specified to get the reports
with app.test_request_context():
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
@ -136,14 +152,16 @@ def _get_auth_cookies():
return cookies
def _get_url_path(view, **kwargs):
def _get_url_path(view: str, **kwargs: Any) -> str:
with app.test_request_context():
return urllib.parse.urljoin(
str(config["WEBDRIVER_BASEURL"]), url_for(view, **kwargs)
)
def create_webdriver():
def create_webdriver() -> Union[
chrome.webdriver.WebDriver, firefox.webdriver.WebDriver
]:
# Create a webdriver for use in fetching reports
if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox":
driver_class = firefox.webdriver.WebDriver
@ -181,7 +199,9 @@ def create_webdriver():
return driver
def destroy_webdriver(driver):
def destroy_webdriver(
driver: Union[chrome.webdriver.WebDriver, firefox.webdriver.WebDriver]
) -> None:
"""
Destroy a driver
"""
@ -198,7 +218,7 @@ def destroy_webdriver(driver):
pass
def deliver_dashboard(schedule):
def deliver_dashboard(schedule: DashboardEmailSchedule) -> None:
"""
Given a schedule, deliver the dashboard as an email report
"""
@ -243,7 +263,7 @@ def deliver_dashboard(schedule):
_deliver_email(schedule, subject, email)
def _get_slice_data(schedule):
def _get_slice_data(schedule: SliceEmailSchedule) -> EmailContent:
slc = schedule.slice
slice_url = _get_url_path(
@ -272,7 +292,7 @@ def _get_slice_data(schedule):
# Parse the csv file and generate HTML
columns = rows.pop(0)
with app.app_context():
with app.app_context(): # type: ignore
body = render_template(
"superset/reports/slice_data.html",
columns=columns,
@ -292,7 +312,7 @@ def _get_slice_data(schedule):
return EmailContent(body, data, None)
def _get_slice_visualization(schedule):
def _get_slice_visualization(schedule: SliceEmailSchedule) -> EmailContent:
slc = schedule.slice
# Create a driver, fetch the page, wait for the page to render
@ -327,7 +347,7 @@ def _get_slice_visualization(schedule):
return _generate_mail_content(schedule, screenshot, slc.slice_name, slice_url)
def deliver_slice(schedule):
def deliver_slice(schedule: Union[DashboardEmailSchedule, SliceEmailSchedule]) -> None:
"""
Given a schedule, deliver the slice as an email report
"""
@ -352,9 +372,12 @@ def deliver_slice(schedule):
bind=True,
soft_time_limit=config["EMAIL_ASYNC_TIME_LIMIT_SEC"],
)
def schedule_email_report(
task, report_type, schedule_id, recipients=None
): # pylint: disable=unused-argument
def schedule_email_report( # pylint: disable=unused-argument
task: Task,
report_type: ScheduleType,
schedule_id: int,
recipients: Optional[str] = None,
) -> None:
model_cls = get_scheduler_model(report_type)
schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
@ -368,15 +391,17 @@ def schedule_email_report(
schedule.id = schedule_id
schedule.recipients = recipients
if report_type == ScheduleType.dashboard.value:
if report_type == ScheduleType.dashboard:
deliver_dashboard(schedule)
elif report_type == ScheduleType.slice.value:
elif report_type == ScheduleType.slice:
deliver_slice(schedule)
else:
raise RuntimeError("Unknown report type")
def next_schedules(crontab, start_at, stop_at, resolution=0):
def next_schedules(
crontab: str, start_at: datetime, stop_at: datetime, resolution: int = 0
) -> Iterator[datetime]:
crons = croniter.croniter(crontab, start_at - timedelta(seconds=1))
previous = start_at - timedelta(days=1)
@ -396,13 +421,19 @@ def next_schedules(crontab, start_at, stop_at, resolution=0):
previous = eta
def schedule_window(report_type, start_at, stop_at, resolution):
def schedule_window(
report_type: ScheduleType, start_at: datetime, stop_at: datetime, resolution: int
) -> None:
"""
Find all active schedules and schedule celery tasks for
each of them with a specific ETA (determined by parsing
the cron schedule for the schedule)
"""
model_cls = get_scheduler_model(report_type)
if not model_cls:
return None
dbsession = db.create_scoped_session()
schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
@ -415,9 +446,11 @@ def schedule_window(report_type, start_at, stop_at, resolution):
):
schedule_email_report.apply_async(args, eta=eta)
return None
@celery_app.task(name="email_reports.schedule_hourly")
def schedule_hourly():
def schedule_hourly() -> None:
""" Celery beat job meant to be invoked hourly """
if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
@ -429,5 +462,5 @@ def schedule_hourly():
# Get the top of the hour
start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
stop_at = start_at + timedelta(seconds=3600)
schedule_window(ScheduleType.dashboard.value, start_at, stop_at, resolution)
schedule_window(ScheduleType.slice.value, start_at, stop_at, resolution)
schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution)
schedule_window(ScheduleType.slice, start_at, stop_at, resolution)
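
As in the other task modules, callers now pass ScheduleType members instead of .value strings, and because get_scheduler_model is typed to return Optional[...], schedule_window gains an explicit guard (and explicit return None statements) before using the model class. The caller-side shape, sketched with stand-in names:

import enum
from typing import Optional, Type


class ScheduleKind(enum.Enum):  # stand-in for ScheduleType
    dashboard = "dashboard"
    slice = "slice"


class Schedule:  # stand-in model base class
    pass


def lookup_model(kind: ScheduleKind) -> Optional[Type[Schedule]]:
    return Schedule if kind == ScheduleKind.dashboard else None


def schedule_window(kind: ScheduleKind) -> None:
    model_cls = lookup_model(kind)
    if not model_cls:
        return None  # nothing to schedule for this report type
    # ... query active schedules for model_cls and enqueue Celery tasks ...
    return None


schedule_window(ScheduleKind.dashboard)
schedule_window(ScheduleKind.slice)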

View File

@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
@celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300)
def cache_chart_thumbnail(chart_id: int, force: bool = False):
with app.app_context():
def cache_chart_thumbnail(chart_id: int, force: bool = False) -> None:
with app.app_context(): # type: ignore
if not thumbnail_cache:
logger.warning("No cache set, refusing to compute")
return None
@ -42,8 +42,8 @@ def cache_chart_thumbnail(chart_id: int, force: bool = False):
@celery_app.task(name="cache_dashboard_thumbnail", soft_time_limit=300)
def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False):
with app.app_context():
def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False) -> None:
with app.app_context(): # type: ignore
if not thumbnail_cache:
logging.warning("No cache set, refusing to compute")
return None

View File

@ -235,7 +235,7 @@ def list_minus(l: List, minus: List) -> List:
return [o for o in l if o not in minus]
def parse_human_datetime(s):
def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
"""
Returns ``datetime.datetime`` from human readable strings
@ -687,42 +687,42 @@ def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, conf
def send_email_smtp(
to,
subject,
html_content,
config,
files=None,
data=None,
images=None,
dryrun=False,
cc=None,
bcc=None,
mime_subtype="mixed",
):
to: str,
subject: str,
html_content: str,
config: Dict[str, Any],
files: Optional[List[str]] = None,
data: Optional[Dict[str, str]] = None,
images: Optional[Dict[str, str]] = None,
dryrun: bool = False,
cc: Optional[str] = None,
bcc: Optional[str] = None,
mime_subtype: str = "mixed",
) -> None:
"""
Send an email with html content, eg:
send_email_smtp(
'test@example.com', 'foo', '<b>Foo</b> bar',['/dev/null'], dryrun=True)
"""
smtp_mail_from = config["SMTP_MAIL_FROM"]
to = get_email_address_list(to)
smtp_mail_to = get_email_address_list(to)
msg = MIMEMultipart(mime_subtype)
msg["Subject"] = subject
msg["From"] = smtp_mail_from
msg["To"] = ", ".join(to)
msg["To"] = ", ".join(smtp_mail_to)
msg.preamble = "This is a multi-part message in MIME format."
recipients = to
recipients = smtp_mail_to
if cc:
cc = get_email_address_list(cc)
msg["CC"] = ", ".join(cc)
recipients = recipients + cc
smtp_mail_cc = get_email_address_list(cc)
msg["CC"] = ", ".join(smtp_mail_cc)
recipients = recipients + smtp_mail_cc
if bcc:
# don't add bcc in header
bcc = get_email_address_list(bcc)
recipients = recipients + bcc
smtp_mail_bcc = get_email_address_list(bcc)
recipients = recipients + smtp_mail_bcc
msg["Date"] = formatdate(localtime=True)
mime_text = MIMEText(html_content, "html")
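
The renames above (to becomes smtp_mail_to, cc becomes smtp_mail_cc, bcc becomes smtp_mail_bcc) exist because reassigning a str parameter with the List[str] returned by get_email_address_list would change its declared type, which mypy rejects; giving the parsed lists their own names keeps each variable's type stable. A stand-alone sketch with a stand-in helper:

from typing import List


def get_email_address_list(addresses: str) -> List[str]:  # stand-in helper
    return [a.strip() for a in addresses.replace(";", ",").split(",") if a.strip()]


def build_recipients(to: str, cc: str = "", bcc: str = "") -> List[str]:
    smtp_mail_to = get_email_address_list(to)  # not: to = get_email_address_list(to)
    recipients = list(smtp_mail_to)
    if cc:
        recipients += get_email_address_list(cc)
    if bcc:
        recipients += get_email_address_list(bcc)  # bcc still stays out of the headers
    return recipients
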
@ -1034,8 +1034,8 @@ def get_since_until(
"""
separator = " : "
relative_start = parse_human_datetime(relative_start if relative_start else "today")
relative_end = parse_human_datetime(relative_end if relative_end else "today")
relative_start = parse_human_datetime(relative_start if relative_start else "today") # type: ignore
relative_end = parse_human_datetime(relative_end if relative_end else "today") # type: ignore
common_time_frames = {
"Last day": (
relative_start - relativedelta(days=1), # type: ignore
@ -1064,8 +1064,8 @@ def get_since_until(
since, until = time_range.split(separator, 1)
if since and since not in common_time_frames:
since = add_ago_to_since(since)
since = parse_human_datetime(since)
until = parse_human_datetime(until)
since = parse_human_datetime(since) # type: ignore
until = parse_human_datetime(until) # type: ignore
elif time_range in common_time_frames:
since, until = common_time_frames[time_range]
elif time_range == "No filter":
@ -1086,8 +1086,8 @@ def get_since_until(
since = since or ""
if since:
since = add_ago_to_since(since)
since = parse_human_datetime(since)
until = parse_human_datetime(until) if until else relative_end
since = parse_human_datetime(since) # type: ignore
until = parse_human_datetime(until) if until else relative_end # type: ignore
if time_shift:
time_delta = parse_past_timedelta(time_shift)
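
The # type: ignore comments in get_since_until mark the opposite trade-off: since and until start out as strings split from time_range, and reassigning them with the Optional[datetime] result of parse_human_datetime changes their type mid-function. A sketch of the issue and of the separate-name style that would avoid the ignore; the parser here is a stand-in, not the real helper:

from datetime import datetime
from typing import Optional, Tuple


def parse_human_datetime(value: str) -> Optional[datetime]:  # stand-in signature
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        return None


def since_until(time_range: str) -> Tuple[Optional[datetime], Optional[datetime]]:
    since_str, _, until_str = time_range.partition(" : ")
    # separate names keep the str and Optional[datetime] values apart, so no ignore is needed
    since_dttm = parse_human_datetime(since_str) if since_str else None
    until_dttm = parse_human_datetime(until_str) if until_str else None
    return since_dttm, until_dttm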