[thumbnails] API and celery task for dashboards and charts (#8947)

This commit is contained in:
Daniel Vaz Gaspar 2020-04-15 09:40:14 +01:00 committed by GitHub
parent 1ccda920fe
commit d81f720502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1141 additions and 13 deletions

View File

@ -647,6 +647,74 @@ section in `config.py`:
This will cache all the charts in the top 5 most popular dashboards every hour.
For other strategies, check the `superset/tasks/cache.py` file.
Caching Thumbnails
------------------
This is an optional feature that can be turned on by activating its feature flag in the config:
.. code-block:: python
FEATURE_FLAGS = {
"THUMBNAILS": True,
"THUMBNAILS_SQLA_LISTENERS": True,
}
For this feature you will need a cache system and celery workers. All thumbnails are stored in the cache and are processed
asynchronously by the workers.
An example config where images are stored on S3 could be:
.. code-block:: python
from flask import Flask
from s3cache.s3cache import S3Cache
...
class CeleryConfig(object):
BROKER_URL = "redis://localhost:6379/0"
CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks", "superset.tasks.thumbnails")
CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
CELERYD_PREFETCH_MULTIPLIER = 10
CELERY_ACKS_LATE = True
CELERY_CONFIG = CeleryConfig
def init_thumbnail_cache(app: Flask) -> S3Cache:
return S3Cache("bucket_name", 'thumbs_cache/')
THUMBNAIL_CACHE_CONFIG = init_thumbnail_cache
# Async selenium thumbnail task will use the following user
THUMBNAIL_SELENIUM_USER = "Admin"
Using the above example, cache keys for dashboards will be `superset_thumb__dashboard__{ID}`.
You can override the base URL for selenium using:
.. code-block:: python
WEBDRIVER_BASEURL = "https://superset.company.com"
Additional selenium web driver config can be set using `WEBDRIVER_CONFIGURATION`.
You can implement a custom function to authenticate selenium; the default uses the flask-login session cookie.
An example of a custom function signature:
.. code-block:: python
def auth_driver(driver: WebDriver, user: "User") -> WebDriver:
pass
Then on config:
.. code-block:: python
WEBDRIVER_AUTH_FUNC = auth_driver
Deeper SQLAlchemy integration
-----------------------------

View File

@ -33,3 +33,4 @@ redis==3.2.1
requests==2.22.0
statsd==3.3.0
tox==3.11.1
pillow==7.0.0

View File

@ -118,6 +118,7 @@ setup(
"hana": ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"],
"dremio": ["sqlalchemy_dremio>=1.1.0"],
"cockroachdb": ["cockroachdb==0.3.3"],
"thumbnails": ["Pillow>=7.0.0, <8.0.0"],
},
python_requires="~=3.6",
author="Apache Software Foundation",

View File

@ -51,3 +51,4 @@ results_backend_use_msgpack = LocalProxy(
lambda: results_backend_manager.should_use_msgpack
)
tables_cache = LocalProxy(lambda: cache_manager.tables_cache)
thumbnail_cache = LocalProxy(lambda: cache_manager.thumbnail_cache)

View File

@ -15,14 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from typing import Any, Dict
import simplejson
from flask import g, make_response, request, Response
from flask import g, make_response, redirect, request, Response, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset.charts.commands.bulk_delete import BulkDeleteChartCommand
from superset.charts.commands.create import CreateChartCommand
from superset.charts.commands.delete import DeleteChartCommand
@ -41,13 +44,16 @@ from superset.charts.schemas import (
ChartPostSchema,
ChartPutSchema,
get_delete_ids_schema,
thumbnail_query_schema,
)
from superset.common.query_context import QueryContext
from superset.constants import RouteMethod
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger, security_manager
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils.core import json_int_dttm_ser
from superset.utils.screenshots import ChartScreenshot
from superset.views.base_api import BaseSupersetModelRestApi, RelatedFieldFilter
from superset.views.filters import FilterRelatedOwners
@ -131,6 +137,11 @@ class ChartRestApi(BaseSupersetModelRestApi):
}
allowed_rel_fields = {"owners"}
def __init__(self) -> None:
    """Expose the ``thumbnail`` route only when the THUMBNAILS feature flag is on."""
    thumbnails_enabled = is_feature_enabled("THUMBNAILS")
    if thumbnails_enabled:
        # union() returns a new set; the set currently referenced by the
        # attribute is left untouched.
        self.include_route_methods = self.include_route_methods.union(
            {"thumbnail"}
        )
    super().__init__()
@expose("/", methods=["POST"])
@protect()
@safe
@ -440,13 +451,9 @@ class ChartRestApi(BaseSupersetModelRestApi):
type: object
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
500:
$ref: '#/components/responses/500'
"""
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
@ -464,3 +471,65 @@ class ChartRestApi(BaseSupersetModelRestApi):
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
return resp
@expose("/<pk>/thumbnail/<digest>/", methods=["GET"])
@protect()
@rison(thumbnail_query_schema)
@safe
def thumbnail(
    self, pk: int, digest: str, **kwargs: Dict[str, bool]
) -> WerkzeugResponse:
    """Get Chart thumbnail
    ---
    get:
      description: Compute or get already computed chart thumbnail from cache
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
      - in: path
        name: digest
        description: A hex digest that makes this chart unique
        schema:
          type: string
      - in: query
        name: q
        content:
          application/json:
            schema:
              type: object
              properties:
                force:
                  type: boolean
                  default: false
      responses:
        200:
          description: Chart thumbnail image
          content:
           image/*:
             schema:
               type: string
               format: binary
        202:
          description: Thumbnail does not exist on cache, fired async to compute
          content:
            application/json:
              schema:
                type: object
                properties:
                  message:
                    type: string
        302:
          description: Redirects to the current digest
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        500:
          $ref: '#/components/responses/500'
    """
    chart = self.datamodel.get(pk, self._base_filters)
    if not chart:
        return self.response_404()
    # If force, request a fresh screenshot from the workers
    if kwargs["rison"].get("force", False):
        cache_chart_thumbnail.delay(chart.id, force=True)
        return self.response(202, message="OK Async")
    # Look the thumbnail up under the model id (not the raw URL segment):
    # chart.id is the value handed to the worker task, so both sides build
    # the cache key from the same id.
    screenshot = ChartScreenshot(chart.id).get_from_cache(cache=thumbnail_cache)
    # If the screenshot does not exist, request one from the workers
    if not screenshot:
        cache_chart_thumbnail.delay(chart.id, force=True)
        return self.response(202, message="OK Async")
    # If the requested digest differs from the chart's current digest,
    # redirect the client to the URL carrying the current one
    if chart.digest != digest:
        return redirect(
            url_for(
                f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest
            )
        )
    return Response(
        FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
    )

View File

@ -23,6 +23,10 @@ from superset.exceptions import SupersetException
from superset.utils import core as utils
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
# Schema for the arguments accepted by the thumbnail endpoint through the
# ``rison`` decorator; the optional boolean ``force`` asks for the thumbnail
# to be recomputed.
thumbnail_query_schema = {
    "type": "object",
    "properties": {"force": {"type": "boolean"}},
}
def validate_json(value: Union[bytes, bytearray, str]) -> None:

View File

@ -19,6 +19,7 @@ import logging
from datetime import datetime
from subprocess import Popen
from sys import stdout
from typing import Type, Union
import click
import yaml
@ -454,6 +455,78 @@ def flower(port, address):
Popen(cmd, shell=True).wait()
@superset.command()
@with_appcontext
@click.option(
    "--asynchronous",
    "-a",
    is_flag=True,
    default=False,
    help="Trigger commands to run remotely on a worker",
)
@click.option(
    "--dashboards_only",
    "-d",
    is_flag=True,
    default=False,
    help="Only process dashboards",
)
@click.option(
    "--charts_only", "-c", is_flag=True, default=False, help="Only process charts"
)
@click.option(
    "--force",
    "-f",
    is_flag=True,
    default=False,
    help="Force refresh, even if previously cached",
)
@click.option("--model_id", "-i", multiple=True)
def compute_thumbnails(
    asynchronous: bool,
    dashboards_only: bool,
    charts_only: bool,
    force: bool,
    model_id: int,
):
    """Compute thumbnails"""
    # NOTE(review): ``--model_id`` is declared with ``multiple=True``, so the
    # value received here is a collection of ids (it is passed to ``in_()``
    # below) even though the annotation says ``int`` -- confirm and fix the hint.
    from superset.models.dashboard import Dashboard
    from superset.models.slice import Slice
    from superset.tasks.thumbnails import (
        cache_chart_thumbnail,
        cache_dashboard_thumbnail,
    )

    def compute_generic_thumbnail(
        friendly_type: str,
        model_cls: Union[Type[Dashboard], Type[Slice]],
        model_id: int,
        compute_func,
    ):
        # Select every row of ``model_cls``, narrowed to the requested ids when
        # any were given, and run (or enqueue) the thumbnail task for each.
        query = db.session.query(model_cls)
        if model_id:
            query = query.filter(model_cls.id.in_(model_id))
        models = query.all()
        total = len(models)
        for position, model in enumerate(models, start=1):
            if asynchronous:
                # ``.delay`` hands the work to a celery worker
                runner, verb = compute_func.delay, "Triggering"
            else:
                # call the task function inline in this process
                runner, verb = compute_func, "Processing"
            click.secho(
                f'{verb} {friendly_type} "{model}" ({position}/{total})', fg="green"
            )
            runner(model.id, force=force)

    # (skip flag, label, model class, task) for each kind of thumbnail
    targets = (
        (charts_only, "dashboard", Dashboard, cache_dashboard_thumbnail),
        (dashboards_only, "chart", Slice, cache_chart_thumbnail),
    )
    for skipped, label, model_cls, task in targets:
        if not skipped:
            compute_generic_thumbnail(label, model_cls, model_id, task)
@superset.command()
@with_appcontext
def load_test_users():

View File

@ -285,6 +285,8 @@ DEFAULT_FEATURE_FLAGS = {
"ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,
"KV_STORE": False,
"PRESTO_EXPAND_DATA": False,
# Exposes API endpoint to compute thumbnails
"THUMBNAILS": False,
"REDUCE_DASHBOARD_BOOTSTRAP_PAYLOAD": False,
"SHARE_QUERIES_VIA_KV_STORE": False,
"SIP_38_VIZ_REARCHITECTURE": False,
@ -312,6 +314,11 @@ FEATURE_FLAGS: Dict[str, bool] = {}
# return feature_flags_dict
GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
# ---------------------------------------------------
# Thumbnail config (behind feature flag)
# ---------------------------------------------------
THUMBNAIL_SELENIUM_USER = "Admin"
THUMBNAIL_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
# ---------------------------------------------------
# Image and file configuration

View File

@ -15,13 +15,16 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from typing import Any, Dict
from flask import g, make_response, request, Response
from flask import g, make_response, redirect, request, Response, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset.constants import RouteMethod
from superset.dashboards.commands.bulk_delete import BulkDeleteDashboardCommand
from superset.dashboards.commands.create import CreateDashboardCommand
@ -42,8 +45,11 @@ from superset.dashboards.schemas import (
DashboardPutSchema,
get_delete_ids_schema,
get_export_ids_schema,
thumbnail_query_schema,
)
from superset.models.dashboard import Dashboard
from superset.tasks.thumbnails import cache_dashboard_thumbnail
from superset.utils.screenshots import DashboardScreenshot
from superset.views.base import generate_download_headers
from superset.views.base_api import BaseSupersetModelRestApi, RelatedFieldFilter
from superset.views.filters import FilterRelatedOwners
@ -81,6 +87,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"url",
"slug",
"table_names",
"thumbnail_url",
]
order_columns = ["dashboard_title", "changed_on", "published", "changed_by_fk"]
list_columns = [
@ -93,6 +100,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"published",
"slug",
"url",
"thumbnail_url",
]
edit_columns = [
"dashboard_title",
@ -123,6 +131,11 @@ class DashboardRestApi(BaseSupersetModelRestApi):
}
allowed_rel_fields = {"owners"}
def __init__(self) -> None:
    """Expose the ``thumbnail`` route only when the THUMBNAILS feature flag is on."""
    thumbnails_enabled = is_feature_enabled("THUMBNAILS")
    if thumbnails_enabled:
        # union() returns a new set; the set currently referenced by the
        # attribute is left untouched.
        self.include_route_methods = self.include_route_methods.union(
            {"thumbnail"}
        )
    super().__init__()
@expose("/", methods=["POST"])
@protect()
@safe
@ -151,12 +164,14 @@ class DashboardRestApi(BaseSupersetModelRestApi):
type: number
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
302:
description: Redirects to the current digest
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
404:
$ref: '#/components/responses/404'
500:
$ref: '#/components/responses/500'
"""
@ -398,3 +413,87 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"Content-Disposition"
]
return resp
@expose("/<pk>/thumbnail/<digest>/", methods=["GET"])
@protect()
@safe
@rison(thumbnail_query_schema)
def thumbnail(
    self, pk: int, digest: str, **kwargs: Dict[str, bool]
) -> WerkzeugResponse:
    """Get Dashboard thumbnail
    ---
    get:
      description: >-
        Compute async or get already computed dashboard thumbnail from cache
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
      - in: path
        name: digest
        description: A hex digest that makes this dashboard unique
        schema:
          type: string
      - in: query
        name: q
        content:
          application/json:
            schema:
              type: object
              properties:
                force:
                  type: boolean
                  default: false
      responses:
        200:
          description: Dashboard thumbnail image
          content:
           image/*:
             schema:
               type: string
               format: binary
        202:
          description: Thumbnail does not exist on cache, fired async to compute
          content:
            application/json:
              schema:
                type: object
                properties:
                  message:
                    type: string
        302:
          description: Redirects to the current digest
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    dashboard = self.datamodel.get(pk, self._base_filters)
    if not dashboard:
        return self.response_404()
    # If force, request a screenshot from the workers
    if kwargs["rison"].get("force", False):
        cache_dashboard_thumbnail.delay(dashboard.id, force=True)
        return self.response(202, message="OK Async")
    # Look the thumbnail up under the model id (not the raw URL segment):
    # dashboard.id is the value handed to the worker task, so both sides
    # build the cache key from the same id.
    screenshot = DashboardScreenshot(dashboard.id).get_from_cache(
        cache=thumbnail_cache
    )
    # If the screenshot does not exist, request one from the workers
    if not screenshot:
        cache_dashboard_thumbnail.delay(dashboard.id, force=True)
        return self.response(202, message="OK Async")
    # If the requested digest differs from the dashboard's current digest,
    # redirect the client to the URL carrying the current one
    if dashboard.digest != digest:
        return redirect(
            url_for(
                f"{self.__class__.__name__}.thumbnail",
                pk=pk,
                digest=dashboard.digest,
            )
        )
    return Response(
        FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
    )

View File

@ -26,6 +26,10 @@ from superset.utils import core as utils
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
# Schema for the arguments accepted by the thumbnail endpoint through the
# ``rison`` decorator; the optional boolean ``force`` asks for the thumbnail
# to be recomputed.
thumbnail_query_schema = {
    "type": "object",
    "properties": {"force": {"type": "boolean"}},
}
def validate_json(value: Union[bytes, bytearray, str]) -> None:

View File

@ -43,6 +43,7 @@ from superset.models.helpers import AuditMixinNullable, ImportMixin
from superset.models.slice import Slice as Slice
from superset.models.tags import DashboardUpdater
from superset.models.user_attributes import UserAttribute
from superset.tasks.thumbnails import cache_dashboard_thumbnail
from superset.utils import core as utils
from superset.utils.dashboard_filter_scopes_converter import (
convert_filter_scopes,
@ -184,6 +185,22 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
title = escape(self.dashboard_title or "<empty>")
return Markup(f'<a href="{self.url}">{title}</a>')
@property
def digest(self) -> str:
    """
    MD5 hex digest of the dashboard's ``position_json``, ``css`` and
    ``json_metadata`` values joined with dots.
    """
    fingerprint_parts = (self.position_json, self.css, self.json_metadata)
    # f-string formatting keeps a missing value rendered as the text "None"
    return utils.md5_hex(".".join(f"{part}" for part in fingerprint_parts))
@property
def thumbnail_url(self) -> str:
    """
    Thumbnail API URL for this dashboard. The current digest is part of the
    path so that the URL changes, and browser caches are bypassed, whenever
    the dashboard changes.
    """
    dashboard_id, current_digest = self.id, self.digest
    return f"/api/v1/dashboard/{dashboard_id}/thumbnail/{current_digest}/"
@property
def changed_by_name(self):
if not self.changed_by:
@ -452,8 +469,20 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
)
def event_after_dashboard_changed(  # pylint: disable=unused-argument
    mapper, connection, target
):
    """Listener callback: enqueue a forced thumbnail refresh for ``target``.

    Only ``target`` (the affected dashboard instance) is used; ``mapper`` and
    ``connection`` are part of the listener signature and ignored.
    """
    dashboard_id = target.id
    cache_dashboard_thumbnail.delay(dashboard_id, force=True)
# events for updating tags
if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert)
sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update)
sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete)
# events for refreshing thumbnails
if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"):
sqla.event.listen(Dashboard, "after_insert", event_after_dashboard_changed)
sqla.event.listen(Dashboard, "after_update", event_after_dashboard_changed)

View File

@ -30,6 +30,7 @@ from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
from superset.legacy import update_time_range
from superset.models.helpers import AuditMixinNullable, ImportMixin
from superset.models.tags import ChartUpdater
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils import core as utils
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
@ -173,6 +174,21 @@ class Slice(
"changed_on": self.changed_on.isoformat(),
}
@property
def digest(self) -> str:
    """
    Returns a MD5 HEX digest that makes this chart unique
    """
    # md5_hex() calls ``.encode()`` on its argument, which fails on None.
    # NOTE(review): assumes ``params`` can be unset for some charts -- hash an
    # empty string in that case instead of raising.
    return utils.md5_hex(self.params or "")
@property
def thumbnail_url(self) -> str:
    """
    Thumbnail API URL for this chart. The current digest is part of the path
    so that the URL changes, and browser caches are bypassed, whenever the
    chart changes.
    """
    chart_id, current_digest = self.id, self.digest
    return f"/api/v1/chart/{chart_id}/thumbnail/{current_digest}/"
@property
def json_data(self) -> str:
return json.dumps(self.data)
@ -306,6 +322,12 @@ def set_related_perm(mapper, connection, target):
target.schema_perm = ds.schema_perm
def event_after_chart_changed(  # pylint: disable=unused-argument
    mapper, connection, target
):
    """Listener callback: enqueue a forced thumbnail refresh for ``target``.

    Only ``target`` (the affected chart instance) is used; ``mapper`` and
    ``connection`` are part of the listener signature and ignored.
    """
    chart_id = target.id
    cache_chart_thumbnail.delay(chart_id, force=True)
sqla.event.listen(Slice, "before_insert", set_related_perm)
sqla.event.listen(Slice, "before_update", set_related_perm)
@ -314,3 +336,8 @@ if is_feature_enabled("TAGGING_SYSTEM"):
sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert)
sqla.event.listen(Slice, "after_update", ChartUpdater.after_update)
sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete)
# events for refreshing thumbnails
if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"):
sqla.event.listen(Slice, "after_insert", event_after_chart_changed)
sqla.event.listen(Slice, "after_update", event_after_chart_changed)

View File

@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
"""Celery tasks that compute and cache chart and dashboard thumbnails"""
import logging
from flask import current_app
from superset import app, security_manager, thumbnail_cache
from superset.extensions import celery_app
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
logger = logging.getLogger(__name__)
@celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300)
def cache_chart_thumbnail(chart_id: int, force: bool = False) -> None:
    """Celery task that screenshots a chart and stores the thumbnail in the cache.

    :param chart_id: Id of the chart to screenshot
    :param force: Passed to ``compute_and_cache``; recompute even if a
        thumbnail is already cached
    """
    with app.app_context():
        if not thumbnail_cache:
            logger.warning("No cache set, refusing to compute")
            return None
        # Use the module logger (not the root ``logging`` module) so the
        # record carries this module's name, with lazy %-style arguments.
        logger.info("Caching chart %s", chart_id)
        screenshot = ChartScreenshot(model_id=chart_id)
        # The screenshot is taken as the configured selenium user
        user = security_manager.find_user(current_app.config["THUMBNAIL_SELENIUM_USER"])
        screenshot.compute_and_cache(user=user, cache=thumbnail_cache, force=force)
@celery_app.task(name="cache_dashboard_thumbnail", soft_time_limit=300)
def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False) -> None:
    """Celery task that screenshots a dashboard and stores the thumbnail in the cache.

    :param dashboard_id: Id of the dashboard to screenshot
    :param force: Passed to ``compute_and_cache``; recompute even if a
        thumbnail is already cached
    """
    with app.app_context():
        if not thumbnail_cache:
            # Use the module logger (not the root ``logging`` module) so the
            # record carries this module's name.
            logger.warning("No cache set, refusing to compute")
            return None
        logger.info("Caching dashboard %s", dashboard_id)
        screenshot = DashboardScreenshot(model_id=dashboard_id)
        # The screenshot is taken as the configured selenium user
        user = security_manager.find_user(current_app.config["THUMBNAIL_SELENIUM_USER"])
        screenshot.compute_and_cache(user=user, cache=thumbnail_cache, force=force)

View File

@ -26,12 +26,16 @@ class CacheManager:
self._tables_cache = None
self._cache = None
self._thumbnail_cache = None
def init_app(self, app: Flask) -> None:
    """Create the general, table-names and thumbnail caches from ``app.config``."""
    config = app.config
    # Each cache is built by ``_setup_cache`` from its own config entry, in
    # this order.
    self._cache = self._setup_cache(app, config["CACHE_CONFIG"])
    self._tables_cache = self._setup_cache(app, config["TABLE_NAMES_CACHE_CONFIG"])
    self._thumbnail_cache = self._setup_cache(
        app, config["THUMBNAIL_CACHE_CONFIG"]
    )
@staticmethod
def _setup_cache(app: Flask, cache_config: CacheConfig) -> Cache:
@ -50,3 +54,7 @@ class CacheManager:
@property
def cache(self) -> Cache:
return self._cache
@property
def thumbnail_cache(self) -> Cache:
    """Cache built from ``THUMBNAIL_CACHE_CONFIG`` by ``init_app``."""
    return self._thumbnail_cache

View File

@ -273,6 +273,10 @@ def dttm_from_timetuple(d: struct_time) -> datetime:
return datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec)
def md5_hex(data: str) -> str:
    """Return the hexadecimal MD5 digest of ``data`` (encoded with ``str.encode()``)."""
    hasher = hashlib.md5()
    hasher.update(data.encode())
    return hasher.hexdigest()
class DashboardEncoder(json.JSONEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@ -0,0 +1,329 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import time
import urllib.parse
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from flask import current_app, request, Response, session, url_for
from flask_login import login_user
from retry.api import retry_call
from selenium.common.exceptions import TimeoutException, WebDriverException
from selenium.webdriver import chrome, firefox
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from werkzeug.http import parse_cookie
logger = logging.getLogger(__name__)
try:
from PIL import Image # pylint: disable=import-error
except ModuleNotFoundError:
logger.info("No PIL installation found")
if TYPE_CHECKING:
# pylint: disable=unused-import
from flask_appbuilder.security.sqla.models import User
from flask_caching import Cache
# Time in seconds, we will wait for the page to load and render
SELENIUM_CHECK_INTERVAL = 2
SELENIUM_RETRIES = 5
SELENIUM_HEADSTART = 3
WindowSize = Tuple[int, int]
def get_auth_cookies(user: "User") -> List[Dict]:
    """
    Log ``user`` in inside a throwaway request context and collect the
    ``session`` value of every ``Set-Cookie`` header the session interface
    writes.

    :param user: The user to log in
    :return: The ``session`` cookie values found on the response
    """
    with current_app.test_request_context("/login"):
        login_user(user)
        # A response object that only exists to receive the session cookie
        response = Response()
        current_app.session_interface.save_session(current_app, session, response)

    return [
        parse_cookie(header_value)["session"]
        for header_name, header_value in response.headers
        if header_name.lower() == "set-cookie"
    ]
def auth_driver(driver: WebDriver, user: "User") -> WebDriver:
    """
    Default webdriver auth function: authenticates ``driver`` with cookies.

    With a user, the session cookie(s) produced by ``get_auth_cookies`` are
    added under the name ``session``. Without one, the cookies of the current
    request (if any) are copied onto the driver.

    :return: The same WebDriver, with the cookies added
    """
    if user:
        for session_value in get_auth_cookies(user):
            driver.add_cookie({"name": "session", "value": session_value})
        return driver

    request_cookies = request.cookies
    if request_cookies:
        for cookie_name, cookie_value in request_cookies.items():
            driver.add_cookie({"name": cookie_name, "value": cookie_value})
    return driver
def headless_url(path: str) -> str:
    """Join ``path`` onto the ``WEBDRIVER_BASEURL`` config value (empty if unset)."""
    base_url = current_app.config.get("WEBDRIVER_BASEURL", "")
    return urllib.parse.urljoin(base_url, path)
def get_url_path(view: str, **kwargs) -> str:
    """Build the URL of ``view`` and resolve it with ``headless_url``."""
    # url_for needs a request context; a test one is pushed for the lookup
    with current_app.test_request_context():
        relative_url = url_for(view, **kwargs)
        return headless_url(relative_url)
class AuthWebDriverProxy:
    """Creates, authenticates and tears down headless selenium webdrivers."""

    def __init__(
        self,
        driver_type: str,
        window: Optional[WindowSize] = None,
        auth_func: Optional[Callable] = None,
    ):
        """
        :param driver_type: ``"firefox"`` or ``"chrome"``
        :param window: Browser window size, defaults to (800, 600)
        :param auth_func: Callable ``(driver, user) -> driver`` used to
            authenticate; defaults to the ``WEBDRIVER_AUTH_FUNC`` config value,
            falling back to ``auth_driver``
        """
        self._driver_type = driver_type
        self._window: WindowSize = window or (800, 600)
        config_auth_func: Callable = current_app.config.get(
            "WEBDRIVER_AUTH_FUNC", auth_driver
        )
        self._auth_func: Callable = auth_func or config_auth_func

    def create(self) -> WebDriver:
        """
        Instantiate a headless webdriver of the configured type.

        :raises Exception: If the driver type is neither firefox nor chrome
        """
        if self._driver_type == "firefox":
            driver_class = firefox.webdriver.WebDriver
            options = firefox.options.Options()
        elif self._driver_type == "chrome":
            driver_class = chrome.webdriver.WebDriver
            options = chrome.options.Options()
            arg: str = f"--window-size={self._window[0]},{self._window[1]}"
            options.add_argument(arg)
        else:
            raise Exception(f"Webdriver name ({self._driver_type}) not supported")
        # Prepare args for the webdriver init
        options.add_argument("--headless")
        kwargs: Dict = dict(options=options)
        kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
        logger.info("Init selenium driver")
        return driver_class(**kwargs)

    def auth(self, user: "User") -> WebDriver:
        """Create a driver, load the login page and run the auth function on it."""
        driver = self.create()
        try:
            # Setting cookies requires doing a request first
            driver.get(headless_url("/login/"))
            return self._auth_func(driver, user)
        except Exception:
            # Do not leave a browser behind when authentication fails; the
            # error itself is still raised to the caller.
            self.destroy(driver)
            raise

    @staticmethod
    def destroy(driver: WebDriver, tries=2):
        """Destroy a driver"""
        # This is some very flaky code in selenium. Hence the retries
        # and catch-all exceptions
        try:
            retry_call(driver.close, tries=tries)
        except Exception:  # pylint: disable=broad-except
            pass
        try:
            driver.quit()
        except Exception:  # pylint: disable=broad-except
            pass

    def get_screenshot(
        self, url: str, element_name: str, user: "User", retries: int = SELENIUM_RETRIES
    ) -> Optional[bytes]:
        """
        Load ``url`` as ``user`` and return a PNG screenshot of the element
        with class ``element_name``.

        :param url: Page to load
        :param element_name: CSS class name of the element to capture
        :param user: User passed to the auth function
        :param retries: Number of tries used when closing the driver
        :return: PNG bytes, or None when selenium timed out
        """
        driver = self.auth(user)
        img: Optional[bytes] = None
        # Everything after the driver exists runs under ``finally`` so the
        # browser is always destroyed, including when navigation itself fails.
        try:
            driver.set_window_size(*self._window)
            driver.get(url)
            logger.debug(f"Sleeping for {SELENIUM_HEADSTART} seconds")
            time.sleep(SELENIUM_HEADSTART)
            try:
                logger.debug(f"Wait for the presence of {element_name}")
                element = WebDriverWait(driver, 10).until(
                    EC.presence_of_element_located((By.CLASS_NAME, element_name))
                )
                logger.debug(f"Wait for .loading to be done")
                WebDriverWait(driver, 60).until_not(
                    EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
                )
                logger.info("Taking a PNG screenshot")
                img = element.screenshot_as_png
            except TimeoutException:
                logger.error("Selenium timed out")
            except WebDriverException as ex:
                logger.error(ex)
                # Some webdrivers do not support screenshots for elements.
                # In such cases, take a screenshot of the entire page.
                # (WebDriver has no ``screenshot()`` method; the PNG bytes of
                # the page come from ``get_screenshot_as_png()``.)
                img = driver.get_screenshot_as_png()
        finally:
            self.destroy(driver, retries)
        return img
class BaseScreenshot:
    """
    Base class for taking a screenshot of a model's page, resizing it into a
    thumbnail and storing it in a cache. Subclasses set ``thumbnail_type``,
    ``element`` and the sizes, and implement ``url``.
    """

    # Webdriver used to take the screenshot
    driver_type = "chrome"
    # Used to build the cache key
    thumbnail_type: str = ""
    # CSS class name of the page element that gets captured
    element: str = ""
    # Browser window size and final thumbnail size, as (width, height)
    window_size: WindowSize = (800, 600)
    thumb_size: WindowSize = (400, 300)

    def __init__(self, model_id: int):
        self.model_id: int = model_id
        self.screenshot: Optional[bytes] = None
        self._driver = AuthWebDriverProxy(self.driver_type, self.window_size)

    @property
    def cache_key(self) -> str:
        """Cache key built from the thumbnail type and the model id."""
        return f"thumb__{self.thumbnail_type}__{self.model_id}"

    @property
    def url(self) -> str:
        """URL of the page to capture; must be provided by subclasses."""
        raise NotImplementedError()

    def get_screenshot(self, user: "User") -> Optional[bytes]:
        """Take the screenshot as ``user``, remember it on the instance and return it."""
        captured = self._driver.get_screenshot(self.url, self.element, user)
        self.screenshot = captured
        return captured

    def get(
        self,
        user: "User" = None,
        cache: "Cache" = None,
        thumb_size: Optional[WindowSize] = None,
    ) -> Optional[BytesIO]:
        """
        Get the thumbnail as BytesIO, from the cache when present there,
        otherwise by computing (and caching) it.

        :param user: Passed through to ``compute_and_cache``
        :param cache: The cache to use
        :param thumb_size: Override the thumbnail size
        :return: The thumbnail wrapped in BytesIO, or None if there is none
        """
        thumb_size = thumb_size or self.thumb_size
        cached = cache.get(self.cache_key) if cache else None
        if cached:
            logger.info(f"Loaded thumbnail from cache: {self.cache_key}")
            image_bytes = cached
        else:
            image_bytes = self.compute_and_cache(
                user=user, thumb_size=thumb_size, cache=cache
            )
        return BytesIO(image_bytes) if image_bytes else None

    def get_from_cache(self, cache: "Cache") -> Optional[BytesIO]:
        """Return the cached thumbnail as BytesIO, or None when nothing is cached."""
        cached = cache.get(self.cache_key)
        return BytesIO(cached) if cached else None

    def compute_and_cache(  # pylint: disable=too-many-arguments
        self,
        user: "User" = None,
        thumb_size: Optional[WindowSize] = None,
        cache: "Cache" = None,
        force: bool = True,
    ) -> Optional[bytes]:
        """
        Take the screenshot, resize it into a thumbnail and cache the result.
        Failures while taking or resizing the screenshot are logged, not raised.

        :param user: The user handed to ``get_screenshot``
        :param thumb_size: The final thumbnail size, defaults to ``self.thumb_size``
        :param cache: The cache to keep the thumbnail payload
        :param force: Compute even if a thumbnail is already cached
        :return: Image payload, or None when skipped or when computing failed
        """
        cache_key = self.cache_key
        if not force:
            if cache and cache.get(cache_key):
                logger.info("Thumb already cached, skipping...")
                return None

        target_size = thumb_size or self.thumb_size
        logger.info(f"Processing url for thumbnail: {cache_key}")

        image_bytes = None
        # Assuming all sorts of things can go wrong with Selenium
        try:
            image_bytes = self.get_screenshot(user=user)
        except Exception as ex:  # pylint: disable=broad-except
            logger.error("Failed at generating thumbnail %s", ex)

        # Only resize when the thumbnail size differs from the window size
        if image_bytes and self.window_size != target_size:
            try:
                image_bytes = self.resize_image(image_bytes, thumb_size=target_size)
            except Exception as ex:  # pylint: disable=broad-except
                logger.error("Failed at resizing thumbnail %s", ex)
                image_bytes = None

        if image_bytes and cache:
            logger.info(f"Caching thumbnail: {cache_key} {cache}")
            cache.set(cache_key, image_bytes)
        return image_bytes

    @classmethod
    def resize_image(
        cls,
        img_bytes: bytes,
        output: str = "png",
        thumb_size: Optional[WindowSize] = None,
        crop: bool = True,
    ) -> bytes:
        """
        Optionally crop the image to the window's aspect ratio, then resize it.

        :param img_bytes: The source image
        :param output: Output format handed to ``img.save``
        :param thumb_size: Target size, defaults to ``cls.thumb_size``
        :param crop: Crop from the top-left corner when the image height
            differs from the window height
        :return: The encoded thumbnail
        """
        thumb_size = thumb_size or cls.thumb_size
        window_width, window_height = cls.window_size
        img = Image.open(BytesIO(img_bytes))
        logger.debug(f"Selenium image size: {img.size}")
        if crop and img.size[1] != window_height:
            desired_ratio = float(window_height) / window_width
            # Despite its name this is the lower edge (height) of the crop box
            desired_width = int(img.size[0] * desired_ratio)
            logger.debug(f"Cropping to: {img.size[0]}*{desired_width}")
            img = img.crop((0, 0, img.size[0], desired_width))
        logger.debug(f"Resizing to {thumb_size}")
        img = img.resize(thumb_size, Image.ANTIALIAS)
        if output != "png":
            img = img.convert("RGB")
        buffer = BytesIO()
        img.save(buffer, output)
        return buffer.getvalue()
class ChartScreenshot(BaseScreenshot):
    """Screenshot of a single chart, captured from the ``chart-container`` element."""

    thumbnail_type: str = "chart"
    element: str = "chart-container"
    # Both sizes are width x (0.75 * width)
    window_size: WindowSize = (600, 450)
    thumb_size: WindowSize = (300, 225)

    @property
    def url(self) -> str:
        """URL of the ``Superset.slice`` view for this chart, with ``standalone=true``."""
        view_args = {"slice_id": self.model_id, "standalone": "true"}
        return get_url_path("Superset.slice", **view_args)
class DashboardScreenshot(BaseScreenshot):
    """Screenshot of a dashboard, captured from the ``grid-container`` element."""

    thumbnail_type: str = "dashboard"
    element: str = "grid-container"
    # Both sizes are width x (0.75 * width)
    window_size: WindowSize = (1600, 1200)
    thumb_size: WindowSize = (400, 300)

    @property
    def url(self) -> str:
        """URL of the ``Superset.dashboard`` view for this dashboard."""
        view_args = {"dashboard_id": self.model_id}
        return get_url_path("Superset.dashboard", **view_args)

View File

@ -22,6 +22,8 @@ from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import BaseFilter, Filters
from flask_appbuilder.models.sqla.filters import FilterStartsWith
from superset.stats_logger import BaseStatsLogger
logger = logging.getLogger(__name__)
get_related_schema = {
"type": "object",
@ -57,6 +59,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
"bulk_delete": "delete",
"info": "list",
"related": "list",
"thumbnail": "list",
"refresh": "edit",
"data": "list",
}
@ -88,9 +91,9 @@ class BaseSupersetModelRestApi(ModelRestApi):
""" # pylint: disable=pointless-string-statement
allowed_rel_fields: Set[str] = set()
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.stats_logger = None
self.stats_logger = BaseStatsLogger()
def create_blueprint(self, appbuilder, *args, **kwargs):
self.stats_logger = self.appbuilder.get_app.config["STATS_LOGGER"]

View File

@ -109,6 +109,7 @@ class DashboardApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
"url": f"/superset/dashboard/slug1/",
"slug": "slug1",
"table_names": "",
"thumbnail_url": dashboard.thumbnail_url,
}
data = json.loads(rv.data.decode("utf-8"))
self.assertIn("changed_on", data["result"])

View File

@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# type: ignore
from copy import copy
from flask import Flask
from werkzeug.contrib.cache import RedisCache
from superset.config import * # type: ignore
AUTH_USER_REGISTRATION_ROLE = "alpha"
# `os` and `DATA_DIR` are not imported explicitly in this module; they are
# picked up from the `from superset.config import *` star import.
# Default database: a SQLite file named unittests.db under DATA_DIR.
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
DEBUG = True
SUPERSET_WEBSERVER_PORT = 8081

# Allowing SQLALCHEMY_DATABASE_URI to be defined as an env var for
# continuous integration
# (must stay after the default assignment above so the env var wins)
if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ:
    SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"]

SQL_SELECT_AS_CTA = True
SQL_MAX_ROW = 666
def GET_FEATURE_FLAGS_FUNC(ff):
    """Return a shallow copy of ``ff`` with the flag ``"super"`` set to ``"set"``.

    The mapping passed in is left untouched.
    """
    flags = copy(ff)
    flags.update({"super": "set"})
    return flags
# Overrides applied on top of the defaults star-imported from superset.config.
TESTING = True
WTF_CSRF_ENABLED = False
PUBLIC_ROLE_LIKE_GAMMA = True
AUTH_ROLE_PUBLIC = "Public"
EMAIL_NOTIFICATIONS = False
# NOTE(review): "simple" presumably selects an in-process cache backend, as
# opposed to the Redis cache used for thumbnails in this file - confirm.
CACHE_CONFIG = {"CACHE_TYPE": "simple"}
class CeleryConfig:
    """Celery settings for this test configuration (broker on local Redis)."""

    BROKER_URL = "redis://localhost"
    # Task modules listed for Celery; includes the thumbnail tasks.
    CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks.thumbnails")
    CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
    CONCURRENCY = 1


# The class itself (not an instance) is what gets exposed as the setting.
CELERY_CONFIG = CeleryConfig
# Feature flags for the thumbnail test run: THUMBNAILS is switched on while
# the SQLAlchemy listeners flag stays off.
FEATURE_FLAGS = {
    # NOTE(review): "foo" looks like a dummy flag used by other tests - confirm.
    "foo": "bar",
    "KV_STORE": False,
    "SHARE_QUERIES_VIA_KV_STORE": False,
    "THUMBNAILS": True,
    "THUMBNAILS_SQLA_LISTENERS": False,
}
def init_thumbnail_cache(app: Flask) -> RedisCache:
    """Build the Redis cache that holds thumbnails during tests.

    ``app`` is part of the factory signature but is not used here.
    """
    redis_options = {
        "host": "localhost",
        "key_prefix": "superset_thumbnails_",
        "default_timeout": 10000,
    }
    return RedisCache(**redis_options)


# The factory function itself is the configured value.
THUMBNAIL_CACHE_CONFIG = init_thumbnail_cache

261
tests/thumbnails_tests.py Normal file
View File

@ -0,0 +1,261 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# from superset import db
# from superset.models.dashboard import Dashboard
import subprocess
import urllib.request
from unittest import skipUnless
from unittest.mock import patch
from flask_testing import LiveServerTestCase
from sqlalchemy.sql import func
from superset import db, is_feature_enabled, security_manager, thumbnail_cache
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.screenshots import (
ChartScreenshot,
DashboardScreenshot,
get_auth_cookies,
)
from tests.test_app import app
from .base_tests import SupersetTestCase
class CeleryStartMixin:
    """Test mixin that runs a Celery worker for the lifetime of a test class.

    ``setUpClass`` points the app at a local Redis broker and a Redis
    thumbnail cache, then launches ``bin/superset worker``;
    ``tearDownClass`` stops the worker again.
    """

    # Handle of the worker launched by ``setUpClass``; ``None`` until started.
    _worker_process = None

    @classmethod
    def setUpClass(cls):
        """Configure Celery and the thumbnail cache, then start a worker."""
        with app.app_context():
            from werkzeug.contrib.cache import RedisCache

            class CeleryConfig(object):
                BROKER_URL = "redis://localhost"
                CELERY_IMPORTS = ("superset.tasks.thumbnails",)
                CONCURRENCY = 1

            app.config["CELERY_CONFIG"] = CeleryConfig

            def init_thumbnail_cache(app) -> RedisCache:
                return RedisCache(
                    host="localhost",
                    key_prefix="superset_thumbnails_",
                    default_timeout=10000,
                )

            app.config["THUMBNAIL_CACHE_CONFIG"] = init_thumbnail_cache

            base_dir = app.config["BASE_DIR"]
            # An argument list (no shell) makes the handle refer to the worker
            # itself and keeps paths containing spaces intact.
            # Output goes to DEVNULL: nothing ever read the previous PIPEs, so
            # the worker could block once the pipe buffer filled up.
            cls._worker_process = subprocess.Popen(
                [base_dir + "/bin/superset", "worker", "-w", "2"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )

    @classmethod
    def tearDownClass(cls):
        """Stop the worker started in ``setUpClass`` and sweep up leftovers."""
        worker = cls._worker_process
        cls._worker_process = None
        if worker is not None and worker.poll() is None:
            # Ask politely first; escalate only if it does not exit in time.
            worker.terminate()
            try:
                worker.wait(timeout=10)
            except subprocess.TimeoutExpired:
                worker.kill()
                worker.wait()
        # Fallback sweep for processes the worker may have left behind.
        # NOTE: this matches by command line, so it can also hit unrelated
        # processes (including the grep itself).
        subprocess.call(
            "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9", shell=True
        )
        subprocess.call(
            "ps auxww | grep 'superset worker' | awk '{print $2}' | xargs kill -9",
            shell=True,
        )
class ThumbnailsSeleniumLive(CeleryStartMixin, LiveServerTestCase):
    """Thumbnail API checks run against a live server."""

    def create_app(self):
        """Return the shared test application for the live server."""
        return app

    def url_open_auth(self, username: str, url: str):
        """Open ``url`` on the live server using ``username``'s session cookie.

        :param username: User whose auth cookies are sent with the request
        :param url: Path relative to the live server root
        :return: The response object returned by the opener
        """
        user = security_manager.find_user(username=username)
        # Constant key: the last cookie yielded wins, and no cookies leaves
        # the dict empty (the lookup below then raises KeyError).
        session_cookie = {"session": value for value in get_auth_cookies(user)}
        http = urllib.request.build_opener()
        cookie_header = ("Cookie", f"session={session_cookie['session']}")
        http.addheaders.append(cookie_header)
        target = f"{self.get_server_url()}/{url}"
        return http.open(target)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_async_dashboard_screenshot(self):
        """
        Thumbnails: Simple get async dashboard screenshot
        """
        first_dashboard = db.session.query(Dashboard).all()[0]
        thumbnail_path = (
            f"api/v1/dashboard/{first_dashboard.id}/thumbnail/{first_dashboard.digest}/"
        )
        with patch("superset.dashboards.api.DashboardRestApi.get"):
            response = self.url_open_auth("admin", thumbnail_path)
            self.assertEqual(response.getcode(), 202)
class ThumbnailsTests(CeleryStartMixin, SupersetTestCase):
    """API tests for the chart and dashboard thumbnail endpoints.

    Tests of the enabled feature are skipped unless the ``THUMBNAILS`` feature
    flag is on; the two ``*_disabled`` tests only run when it is off.
    """

    # Stand-in payload stored in the thumbnail cache by the "cached" tests.
    mock_image = b"bytes mock image"

    def test_dashboard_thumbnail_disabled(self):
        """
        Thumbnails: Dashboard thumbnail disabled
        """
        if is_feature_enabled("THUMBNAILS"):
            # Report as skipped instead of silently passing.
            self.skipTest("Thumbnails feature is enabled")
        dashboard = db.session.query(Dashboard).all()[0]
        self.login(username="admin")
        uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    def test_chart_thumbnail_disabled(self):
        """
        Thumbnails: Chart thumbnail disabled
        """
        if is_feature_enabled("THUMBNAILS"):
            # Report as skipped instead of silently passing.
            self.skipTest("Thumbnails feature is enabled")
        chart = db.session.query(Slice).all()[0]
        self.login(username="admin")
        # Use the chart id (not the model's string form) so the request
        # targets the real thumbnail route.
        uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_async_dashboard_screenshot(self):
        """
        Thumbnails: Simple get async dashboard screenshot
        """
        dashboard = db.session.query(Dashboard).all()[0]
        self.login(username="admin")
        uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
        # The Celery task is mocked: only its dispatch is checked.
        with patch(
            "superset.tasks.thumbnails.cache_dashboard_thumbnail.delay"
        ) as mock_task:
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 202)
            mock_task.assert_called_with(dashboard.id, force=True)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_async_dashboard_notfound(self):
        """
        Thumbnails: Simple get async dashboard not found
        """
        max_id = db.session.query(func.max(Dashboard.id)).scalar()
        self.login(username="admin")
        # max_id + 1 is an id no existing dashboard has.
        uri = f"api/v1/dashboard/{max_id + 1}/thumbnail/1234/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_async_dashboard_not_allowed(self):
        """
        Thumbnails: Simple get async dashboard not allowed
        """
        dashboard = db.session.query(Dashboard).all()[0]
        self.login(username="gamma")
        uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_async_chart_screenshot(self):
        """
        Thumbnails: Simple get async chart screenshot
        """
        chart = db.session.query(Slice).all()[0]
        self.login(username="admin")
        uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/"
        # The Celery task is mocked: only its dispatch is checked.
        with patch(
            "superset.tasks.thumbnails.cache_chart_thumbnail.delay"
        ) as mock_task:
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 202)
            mock_task.assert_called_with(chart.id, force=True)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_async_chart_notfound(self):
        """
        Thumbnails: Simple get async chart not found
        """
        max_id = db.session.query(func.max(Slice.id)).scalar()
        self.login(username="admin")
        # max_id + 1 is an id no existing chart has.
        uri = f"api/v1/chart/{max_id + 1}/thumbnail/1234/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_cached_chart_wrong_digest(self):
        """
        Thumbnails: Simple get chart with wrong digest
        """
        chart = db.session.query(Slice).all()[0]
        # Cache a test "image"
        screenshot = ChartScreenshot(model_id=chart.id)
        thumbnail_cache.set(screenshot.cache_key, self.mock_image)
        self.login(username="admin")
        uri = f"api/v1/chart/{chart.id}/thumbnail/1234/"
        rv = self.client.get(uri)
        # A stale digest is expected to redirect to the current one.
        self.assertEqual(rv.status_code, 302)
        self.assertRedirects(rv, f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/")

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_cached_dashboard_screenshot(self):
        """
        Thumbnails: Simple get cached dashboard screenshot
        """
        dashboard = db.session.query(Dashboard).all()[0]
        # Cache a test "image"
        screenshot = DashboardScreenshot(model_id=dashboard.id)
        thumbnail_cache.set(screenshot.cache_key, self.mock_image)
        self.login(username="admin")
        uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(rv.data, self.mock_image)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_cached_chart_screenshot(self):
        """
        Thumbnails: Simple get cached chart screenshot
        """
        chart = db.session.query(Slice).all()[0]
        # Cache a test "image"
        screenshot = ChartScreenshot(model_id=chart.id)
        thumbnail_cache.set(screenshot.cache_key, self.mock_image)
        self.login(username="admin")
        uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(rv.data, self.mock_image)

    @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
    def test_get_cached_dashboard_wrong_digest(self):
        """
        Thumbnails: Simple get dashboard with wrong digest
        """
        dashboard = db.session.query(Dashboard).all()[0]
        # Cache a test "image"
        screenshot = DashboardScreenshot(model_id=dashboard.id)
        thumbnail_cache.set(screenshot.cache_key, self.mock_image)
        self.login(username="admin")
        uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/1234/"
        rv = self.client.get(uri)
        # A stale digest is expected to redirect to the current one.
        self.assertEqual(rv.status_code, 302)
        self.assertRedirects(
            rv, f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
        )

View File

@ -33,6 +33,14 @@ setenv =
whitelist_externals =
npm
[testenv:thumbnails]
setenv =
SUPERSET_CONFIG = tests.superset_test_config_thumbnails
deps =
-rrequirements.txt
-rrequirements-dev.txt
.[postgres]
[testenv:black]
commands =
black --check setup.py superset tests