From d81f720502375101b43ffe9e2e6c28c0687b00f9 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Wed, 15 Apr 2020 09:40:14 +0100 Subject: [PATCH] [thumbnails] API and celery task for dashboards and charts (#8947) --- docs/installation.rst | 68 +++++ requirements-dev.txt | 1 + setup.py | 1 + superset/__init__.py | 1 + superset/charts/api.py | 83 +++++- superset/charts/schemas.py | 4 + superset/cli.py | 73 +++++ superset/config.py | 7 + superset/dashboards/api.py | 107 +++++++- superset/dashboards/schemas.py | 4 + superset/models/dashboard.py | 29 ++ superset/models/slice.py | 27 ++ superset/tasks/thumbnails.py | 53 ++++ superset/utils/cache_manager.py | 8 + superset/utils/core.py | 4 + superset/utils/screenshots.py | 329 +++++++++++++++++++++++ superset/views/base_api.py | 7 +- tests/dashboards/api_tests.py | 1 + tests/superset_test_config_thumbnails.py | 78 ++++++ tests/thumbnails_tests.py | 261 ++++++++++++++++++ tox.ini | 8 + 21 files changed, 1141 insertions(+), 13 deletions(-) create mode 100644 superset/tasks/thumbnails.py create mode 100644 superset/utils/screenshots.py create mode 100644 tests/superset_test_config_thumbnails.py create mode 100644 tests/thumbnails_tests.py diff --git a/docs/installation.rst b/docs/installation.rst index 35280570f6..baa14a67f5 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -647,6 +647,74 @@ section in `config.py`: This will cache all the charts in the top 5 most popular dashboards every hour. For other strategies, check the `superset/tasks/cache.py` file. +Caching Thumbnails +------------------ + +This is an optional feature that can be turned on by activating its feature flag in the config: + +.. code-block:: python + + FEATURE_FLAGS = { + "THUMBNAILS": True, + "THUMBNAILS_SQLA_LISTENERS": True, + } + + +For this feature you will need a cache system and celery workers. All thumbnails are stored in the cache and are processed +asynchronously by the workers.
+ +An example config where images are stored on S3 could be: + +.. code-block:: python + + from flask import Flask + from s3cache.s3cache import S3Cache + + ... + + class CeleryConfig(object): + BROKER_URL = "redis://localhost:6379/0" + CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks", "superset.tasks.thumbnails") + CELERY_RESULT_BACKEND = "redis://localhost:6379/0" + CELERYD_PREFETCH_MULTIPLIER = 10 + CELERY_ACKS_LATE = True + + + CELERY_CONFIG = CeleryConfig + + def init_thumbnail_cache(app: Flask) -> S3Cache: + return S3Cache("bucket_name", 'thumbs_cache/') + + + THUMBNAIL_CACHE_CONFIG = init_thumbnail_cache + # Async selenium thumbnail task will use the following user + THUMBNAIL_SELENIUM_USER = "Admin" + +Using the above example, cache keys for dashboards will be `superset_thumb__dashboard__{ID}` + +You can override the base URL for selenium using: + +.. code-block:: python + + WEBDRIVER_BASEURL = "https://superset.company.com" + + +Additional selenium webdriver configuration can be set using `WEBDRIVER_CONFIGURATION` + +You can implement a custom function to authenticate selenium; the default uses the flask-login session cookie. +An example of a custom function signature: + +.. code-block:: python + + def auth_driver(driver: WebDriver, user: "User") -> WebDriver: + pass + + +Then in the config: + +.. 
code-block:: python + + WEBDRIVER_AUTH_FUNC = auth_driver Deeper SQLAlchemy integration ----------------------------- diff --git a/requirements-dev.txt b/requirements-dev.txt index 77820f4d49..a408b03ea0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,3 +33,4 @@ redis==3.2.1 requests==2.22.0 statsd==3.3.0 tox==3.11.1 +pillow==7.0.0 diff --git a/setup.py b/setup.py index 6ca8b25145..6c1484914e 100644 --- a/setup.py +++ b/setup.py @@ -118,6 +118,7 @@ setup( "hana": ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"], "dremio": ["sqlalchemy_dremio>=1.1.0"], "cockroachdb": ["cockroachdb==0.3.3"], + "thumbnails": ["Pillow>=7.0.0, <8.0.0"], }, python_requires="~=3.6", author="Apache Software Foundation", diff --git a/superset/__init__.py b/superset/__init__.py index cc92f3d9ce..3e26c3fb74 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -51,3 +51,4 @@ results_backend_use_msgpack = LocalProxy( lambda: results_backend_manager.should_use_msgpack ) tables_cache = LocalProxy(lambda: cache_manager.tables_cache) +thumbnail_cache = LocalProxy(lambda: cache_manager.thumbnail_cache) diff --git a/superset/charts/api.py b/superset/charts/api.py index 74776c1934..be0df118b2 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -15,14 +15,17 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import Any +from typing import Any, Dict import simplejson -from flask import g, make_response, request, Response +from flask import g, make_response, redirect, request, Response, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext +from werkzeug.wrappers import Response as WerkzeugResponse +from werkzeug.wsgi import FileWrapper +from superset import is_feature_enabled, thumbnail_cache from superset.charts.commands.bulk_delete import BulkDeleteChartCommand from superset.charts.commands.create import CreateChartCommand from superset.charts.commands.delete import DeleteChartCommand @@ -41,13 +44,16 @@ from superset.charts.schemas import ( ChartPostSchema, ChartPutSchema, get_delete_ids_schema, + thumbnail_query_schema, ) from superset.common.query_context import QueryContext from superset.constants import RouteMethod from superset.exceptions import SupersetSecurityException from superset.extensions import event_logger, security_manager from superset.models.slice import Slice +from superset.tasks.thumbnails import cache_chart_thumbnail from superset.utils.core import json_int_dttm_ser +from superset.utils.screenshots import ChartScreenshot from superset.views.base_api import BaseSupersetModelRestApi, RelatedFieldFilter from superset.views.filters import FilterRelatedOwners @@ -131,6 +137,11 @@ class ChartRestApi(BaseSupersetModelRestApi): } allowed_rel_fields = {"owners"} + def __init__(self) -> None: + if is_feature_enabled("THUMBNAILS"): + self.include_route_methods = self.include_route_methods | {"thumbnail"} + super().__init__() + @expose("/", methods=["POST"]) @protect() @safe @@ -440,13 +451,9 @@ class ChartRestApi(BaseSupersetModelRestApi): type: object 400: $ref: '#/components/responses/400' - 401: - $ref: '#/components/responses/401' - 404: - $ref: '#/components/responses/404' 500: $ref: '#/components/responses/500' - 
""" + """ if not request.is_json: return self.response_400(message="Request is not JSON") try: @@ -464,3 +471,65 @@ class ChartRestApi(BaseSupersetModelRestApi): resp = make_response(response_data, 200) resp.headers["Content-Type"] = "application/json; charset=utf-8" return resp + + @expose("//thumbnail//", methods=["GET"]) + @protect() + @rison(thumbnail_query_schema) + @safe + def thumbnail( + self, pk: int, digest: str, **kwargs: Dict[str, bool] + ) -> WerkzeugResponse: + """Get Chart thumbnail + --- + get: + description: Compute or get already computed chart thumbnail from cache + parameters: + - in: path + schema: + type: integer + name: pk + - in: path + schema: + type: string + name: sha + responses: + 200: + description: Chart thumbnail image + content: + image/*: + schema: + type: string + format: binary + 302: + description: Redirects to the current digest + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + chart = self.datamodel.get(pk, self._base_filters) + if not chart: + return self.response_404() + if kwargs["rison"].get("force", False): + cache_chart_thumbnail.delay(chart.id, force=True) + return self.response(202, message="OK Async") + # fetch the chart screenshot using the current user and cache if set + screenshot = ChartScreenshot(pk).get_from_cache(cache=thumbnail_cache) + # If not screenshot then send request to compute thumb to celery + if not screenshot: + cache_chart_thumbnail.delay(chart.id, force=True) + return self.response(202, message="OK Async") + # If digests + if chart.digest != digest: + return redirect( + url_for( + f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest + ) + ) + return Response( + FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True + ) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 96dc3d3606..bf1b57b321 100644 --- 
a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -23,6 +23,10 @@ from superset.exceptions import SupersetException from superset.utils import core as utils get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} +thumbnail_query_schema = { + "type": "object", + "properties": {"force": {"type": "boolean"}}, +} def validate_json(value: Union[bytes, bytearray, str]) -> None: diff --git a/superset/cli.py b/superset/cli.py index 2da47e162d..3bf10d7592 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -19,6 +19,7 @@ import logging from datetime import datetime from subprocess import Popen from sys import stdout +from typing import Type, Union import click import yaml @@ -454,6 +455,78 @@ def flower(port, address): Popen(cmd, shell=True).wait() +@superset.command() +@with_appcontext +@click.option( + "--asynchronous", + "-a", + is_flag=True, + default=False, + help="Trigger commands to run remotely on a worker", +) +@click.option( + "--dashboards_only", + "-d", + is_flag=True, + default=False, + help="Only process dashboards", +) +@click.option( + "--charts_only", "-c", is_flag=True, default=False, help="Only process charts" +) +@click.option( + "--force", + "-f", + is_flag=True, + default=False, + help="Force refresh, even if previously cached", +) +@click.option("--model_id", "-i", multiple=True) +def compute_thumbnails( + asynchronous: bool, + dashboards_only: bool, + charts_only: bool, + force: bool, + model_id: int, +): + """Compute thumbnails""" + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.tasks.thumbnails import ( + cache_chart_thumbnail, + cache_dashboard_thumbnail, + ) + + def compute_generic_thumbnail( + friendly_type: str, + model_cls: Union[Type[Dashboard], Type[Slice]], + model_id: int, + compute_func, + ): + query = db.session.query(model_cls) + if model_id: + query = query.filter(model_cls.id.in_(model_id)) + dashboards = query.all() + count = len(dashboards) 
+ for i, model in enumerate(dashboards): + if asynchronous: + func = compute_func.delay + action = "Triggering" + else: + func = compute_func + action = "Processing" + msg = f'{action} {friendly_type} "{model}" ({i+1}/{count})' + click.secho(msg, fg="green") + func(model.id, force=force) + + if not charts_only: + compute_generic_thumbnail( + "dashboard", Dashboard, model_id, cache_dashboard_thumbnail + ) + if not dashboards_only: + compute_generic_thumbnail("chart", Slice, model_id, cache_chart_thumbnail) + + @superset.command() @with_appcontext def load_test_users(): diff --git a/superset/config.py b/superset/config.py index 91866fbfe2..b719397ce4 100644 --- a/superset/config.py +++ b/superset/config.py @@ -285,6 +285,8 @@ DEFAULT_FEATURE_FLAGS = { "ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False, "KV_STORE": False, "PRESTO_EXPAND_DATA": False, + # Exposes API endpoint to compute thumbnails + "THUMBNAILS": False, "REDUCE_DASHBOARD_BOOTSTRAP_PAYLOAD": False, "SHARE_QUERIES_VIA_KV_STORE": False, "SIP_38_VIZ_REARCHITECTURE": False, @@ -312,6 +314,11 @@ FEATURE_FLAGS: Dict[str, bool] = {} # return feature_flags_dict GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None +# --------------------------------------------------- +# Thumbnail config (behind feature flag) +# --------------------------------------------------- +THUMBNAIL_SELENIUM_USER = "Admin" +THUMBNAIL_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"} # --------------------------------------------------- # Image and file configuration diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index eb4f796cd0..993795142b 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -15,13 +15,16 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import Any +from typing import Any, Dict -from flask import g, make_response, request, Response +from flask import g, make_response, redirect, request, Response, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext +from werkzeug.wrappers import Response as WerkzeugResponse +from werkzeug.wsgi import FileWrapper +from superset import is_feature_enabled, thumbnail_cache from superset.constants import RouteMethod from superset.dashboards.commands.bulk_delete import BulkDeleteDashboardCommand from superset.dashboards.commands.create import CreateDashboardCommand @@ -42,8 +45,11 @@ from superset.dashboards.schemas import ( DashboardPutSchema, get_delete_ids_schema, get_export_ids_schema, + thumbnail_query_schema, ) from superset.models.dashboard import Dashboard +from superset.tasks.thumbnails import cache_dashboard_thumbnail +from superset.utils.screenshots import DashboardScreenshot from superset.views.base import generate_download_headers from superset.views.base_api import BaseSupersetModelRestApi, RelatedFieldFilter from superset.views.filters import FilterRelatedOwners @@ -81,6 +87,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): "url", "slug", "table_names", + "thumbnail_url", ] order_columns = ["dashboard_title", "changed_on", "published", "changed_by_fk"] list_columns = [ @@ -93,6 +100,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): "published", "slug", "url", + "thumbnail_url", ] edit_columns = [ "dashboard_title", @@ -123,6 +131,11 @@ class DashboardRestApi(BaseSupersetModelRestApi): } allowed_rel_fields = {"owners"} + def __init__(self) -> None: + if is_feature_enabled("THUMBNAILS"): + self.include_route_methods = self.include_route_methods | {"thumbnail"} + super().__init__() + @expose("/", methods=["POST"]) @protect() @safe @@ -151,12 +164,14 @@ class DashboardRestApi(BaseSupersetModelRestApi): type: 
number result: $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + 302: + description: Redirects to the current digest 400: $ref: '#/components/responses/400' 401: $ref: '#/components/responses/401' - 422: - $ref: '#/components/responses/422' + 404: + $ref: '#/components/responses/404' 500: $ref: '#/components/responses/500' """ @@ -398,3 +413,87 @@ class DashboardRestApi(BaseSupersetModelRestApi): "Content-Disposition" ] return resp + + @expose("//thumbnail//", methods=["GET"]) + @protect() + @safe + @rison(thumbnail_query_schema) + def thumbnail( + self, pk: int, digest: str, **kwargs: Dict[str, bool] + ) -> WerkzeugResponse: + """Get Dashboard thumbnail + --- + get: + description: >- + Compute async or get already computed dashboard thumbnail from cache + parameters: + - in: path + schema: + type: integer + name: pk + - in: path + name: digest + description: A hex digest that makes this dashboard unique + schema: + type: string + - in: query + name: q + content: + application/json: + schema: + type: object + properties: + force: + type: boolean + default: false + responses: + 200: + description: Dashboard thumbnail image + content: + image/*: + schema: + type: string + format: binary + 202: + description: Thumbnail does not exist on cache, fired async to compute + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + dashboard = self.datamodel.get(pk, self._base_filters) + if not dashboard: + return self.response_404() + # If force, request a screenshot from the workers + if kwargs["rison"].get("force", False): + cache_dashboard_thumbnail.delay(dashboard.id, force=True) + return self.response(202, message="OK Async") + # fetch the dashboard screenshot using the current user and cache if set + screenshot = 
DashboardScreenshot(pk).get_from_cache(cache=thumbnail_cache) + # If the screenshot does not exist, request one from the workers + if not screenshot: + cache_dashboard_thumbnail.delay(dashboard.id, force=True) + return self.response(202, message="OK Async") + # If digests + if dashboard.digest != digest: + return redirect( + url_for( + f"{self.__class__.__name__}.thumbnail", + pk=pk, + digest=dashboard.digest, + ) + ) + return Response( + FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True + ) diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index 902130e5b3..201c4ca49d 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -26,6 +26,10 @@ from superset.utils import core as utils get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} get_export_ids_schema = {"type": "array", "items": {"type": "integer"}} +thumbnail_query_schema = { + "type": "object", + "properties": {"force": {"type": "boolean"}}, +} def validate_json(value: Union[bytes, bytearray, str]) -> None: diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index cd3b0f95fd..c86f8ff83f 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -43,6 +43,7 @@ from superset.models.helpers import AuditMixinNullable, ImportMixin from superset.models.slice import Slice as Slice from superset.models.tags import DashboardUpdater from superset.models.user_attributes import UserAttribute +from superset.tasks.thumbnails import cache_dashboard_thumbnail from superset.utils import core as utils from superset.utils.dashboard_filter_scopes_converter import ( convert_filter_scopes, @@ -184,6 +185,22 @@ class Dashboard( # pylint: disable=too-many-instance-attributes title = escape(self.dashboard_title or "") return Markup(f'{title}') + @property + def digest(self) -> str: + """ + Returns a MD5 HEX digest that makes this dashboard unique + """ + unique_string = 
f"{self.position_json}.{self.css}.{self.json_metadata}" + return utils.md5_hex(unique_string) + + @property + def thumbnail_url(self) -> str: + """ + Returns a thumbnail URL with a HEX digest. We want to avoid browser cache + if the dashboard has changed + """ + return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/" + @property def changed_by_name(self): if not self.changed_by: @@ -452,8 +469,20 @@ class Dashboard( # pylint: disable=too-many-instance-attributes ) +def event_after_dashboard_changed( # pylint: disable=unused-argument + mapper, connection, target +): + cache_dashboard_thumbnail.delay(target.id, force=True) + + # events for updating tags if is_feature_enabled("TAGGING_SYSTEM"): sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert) sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update) sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete) + + +# events for refreshing cached thumbnails when a dashboard is inserted or updated +if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"): + sqla.event.listen(Dashboard, "after_insert", event_after_dashboard_changed) + sqla.event.listen(Dashboard, "after_update", event_after_dashboard_changed) diff --git a/superset/models/slice.py b/superset/models/slice.py index 59544407c8..24a05f205a 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -30,6 +30,7 @@ from superset import ConnectorRegistry, db, is_feature_enabled, security_manager from superset.legacy import update_time_range from superset.models.helpers import AuditMixinNullable, ImportMixin from superset.models.tags import ChartUpdater +from superset.tasks.thumbnails import cache_chart_thumbnail from superset.utils import core as utils if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): @@ -173,6 +174,21 @@ class Slice( "changed_on": self.changed_on.isoformat(), } + @property + def digest(self) -> str: + """ + Returns a MD5 HEX digest that makes this chart unique + """ + return utils.md5_hex(self.params) + + @property + def 
thumbnail_url(self) -> str: + """ + Returns a thumbnail URL with a HEX digest. We want to avoid browser cache + if the chart has changed + """ + return f"/api/v1/chart/{self.id}/thumbnail/{self.digest}/" + @property def json_data(self) -> str: return json.dumps(self.data) @@ -306,6 +322,12 @@ def set_related_perm(mapper, connection, target): target.schema_perm = ds.schema_perm +def event_after_chart_changed( # pylint: disable=unused-argument + mapper, connection, target +): + cache_chart_thumbnail.delay(target.id, force=True) + + sqla.event.listen(Slice, "before_insert", set_related_perm) sqla.event.listen(Slice, "before_update", set_related_perm) @@ -314,3 +336,8 @@ if is_feature_enabled("TAGGING_SYSTEM"): sqla.event.listen(Slice, "after_insert", ChartUpdater.after_insert) sqla.event.listen(Slice, "after_update", ChartUpdater.after_update) sqla.event.listen(Slice, "after_delete", ChartUpdater.after_delete) + +# events for refreshing cached thumbnails when a chart is inserted or updated +if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"): + sqla.event.listen(Slice, "after_insert", event_after_chart_changed) + sqla.event.listen(Slice, "after_update", event_after_chart_changed) diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py new file mode 100644 index 0000000000..72c7bdaf67 --- /dev/null +++ b/superset/tasks/thumbnails.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=C,R,W + +"""Utility functions used across Superset""" + +import logging + +from flask import current_app + +from superset import app, security_manager, thumbnail_cache +from superset.extensions import celery_app +from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot + +logger = logging.getLogger(__name__) + + +@celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300) +def cache_chart_thumbnail(chart_id: int, force: bool = False): + with app.app_context(): + if not thumbnail_cache: + logger.warning("No cache set, refusing to compute") + return None + logging.info(f"Caching chart {chart_id}") + screenshot = ChartScreenshot(model_id=chart_id) + user = security_manager.find_user(current_app.config["THUMBNAIL_SELENIUM_USER"]) + screenshot.compute_and_cache(user=user, cache=thumbnail_cache, force=force) + + +@celery_app.task(name="cache_dashboard_thumbnail", soft_time_limit=300) +def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False): + with app.app_context(): + if not thumbnail_cache: + logging.warning("No cache set, refusing to compute") + return None + logger.info(f"Caching dashboard {dashboard_id}") + screenshot = DashboardScreenshot(model_id=dashboard_id) + user = security_manager.find_user(current_app.config["THUMBNAIL_SELENIUM_USER"]) + screenshot.compute_and_cache(user=user, cache=thumbnail_cache, force=force) diff --git a/superset/utils/cache_manager.py b/superset/utils/cache_manager.py index a098a16d3c..4a625ea6d0 100644 --- a/superset/utils/cache_manager.py +++ 
b/superset/utils/cache_manager.py @@ -26,12 +26,16 @@ class CacheManager: self._tables_cache = None self._cache = None + self._thumbnail_cache = None def init_app(self, app: Flask) -> None: self._cache = self._setup_cache(app, app.config["CACHE_CONFIG"]) self._tables_cache = self._setup_cache( app, app.config["TABLE_NAMES_CACHE_CONFIG"] ) + self._thumbnail_cache = self._setup_cache( + app, app.config["THUMBNAIL_CACHE_CONFIG"] + ) @staticmethod def _setup_cache(app: Flask, cache_config: CacheConfig) -> Cache: @@ -50,3 +54,7 @@ class CacheManager: @property def cache(self) -> Cache: return self._cache + + @property + def thumbnail_cache(self) -> Cache: + return self._thumbnail_cache diff --git a/superset/utils/core.py b/superset/utils/core.py index e72c8ccabd..ca49989a5f 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -273,6 +273,10 @@ def dttm_from_timetuple(d: struct_time) -> datetime: return datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) +def md5_hex(data: str) -> str: + return hashlib.md5(data.encode()).hexdigest() + + class DashboardEncoder(json.JSONEncoder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py new file mode 100644 index 0000000000..18283e7f0d --- /dev/null +++ b/superset/utils/screenshots.py @@ -0,0 +1,329 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +import time +import urllib.parse +from io import BytesIO +from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING + +from flask import current_app, request, Response, session, url_for +from flask_login import login_user +from retry.api import retry_call +from selenium.common.exceptions import TimeoutException, WebDriverException +from selenium.webdriver import chrome, firefox +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait +from werkzeug.http import parse_cookie + +logger = logging.getLogger(__name__) + +try: + from PIL import Image # pylint: disable=import-error +except ModuleNotFoundError: + logger.info("No PIL installation found") + +if TYPE_CHECKING: + # pylint: disable=unused-import + from flask_appbuilder.security.sqla.models import User + from flask_caching import Cache + +# Time in seconds, we will wait for the page to load and render +SELENIUM_CHECK_INTERVAL = 2 +SELENIUM_RETRIES = 5 +SELENIUM_HEADSTART = 3 + +WindowSize = Tuple[int, int] + + +def get_auth_cookies(user: "User") -> List[Dict]: + # Login with the user specified to get the reports + with current_app.test_request_context("/login"): + login_user(user) + # A mock response object to get the cookie information from + response = Response() + current_app.session_interface.save_session(current_app, session, response) + + cookies = [] + + # Set the cookies in the driver + for name, 
value in response.headers: + if name.lower() == "set-cookie": + cookie = parse_cookie(value) + cookies.append(cookie["session"]) + return cookies + + +def auth_driver(driver: WebDriver, user: "User") -> WebDriver: + """ + Default AuthDriverFuncType type that sets a session cookie flask-login style + :return: WebDriver + """ + if user: + # Set the cookies in the driver + for cookie in get_auth_cookies(user): + info = dict(name="session", value=cookie) + driver.add_cookie(info) + elif request.cookies: + cookies = request.cookies + for k, v in cookies.items(): + cookie = dict(name=k, value=v) + driver.add_cookie(cookie) + return driver + + +def headless_url(path: str) -> str: + return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path) + + +def get_url_path(view: str, **kwargs) -> str: + with current_app.test_request_context(): + return headless_url(url_for(view, **kwargs)) + + +class AuthWebDriverProxy: + def __init__( + self, + driver_type: str, + window: Optional[WindowSize] = None, + auth_func: Optional[Callable] = None, + ): + self._driver_type = driver_type + self._window: WindowSize = window or (800, 600) + config_auth_func: Callable = current_app.config.get( + "WEBDRIVER_AUTH_FUNC", auth_driver + ) + self._auth_func: Callable = auth_func or config_auth_func + + def create(self) -> WebDriver: + if self._driver_type == "firefox": + driver_class = firefox.webdriver.WebDriver + options = firefox.options.Options() + elif self._driver_type == "chrome": + driver_class = chrome.webdriver.WebDriver + options = chrome.options.Options() + arg: str = f"--window-size={self._window[0]},{self._window[1]}" + options.add_argument(arg) + else: + raise Exception(f"Webdriver name ({self._driver_type}) not supported") + # Prepare args for the webdriver init + options.add_argument("--headless") + kwargs: Dict = dict(options=options) + kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"]) + logger.info("Init selenium driver") + return 
driver_class(**kwargs) + + def auth(self, user: "User") -> WebDriver: + # Setting cookies requires doing a request first + driver = self.create() + driver.get(headless_url("/login/")) + return self._auth_func(driver, user) + + @staticmethod + def destroy(driver: WebDriver, tries=2): + """Destroy a driver""" + # This is some very flaky code in selenium. Hence the retries + # and catch-all exceptions + try: + retry_call(driver.close, tries=tries) + except Exception: # pylint: disable=broad-except + pass + try: + driver.quit() + except Exception: # pylint: disable=broad-except + pass + + def get_screenshot( + self, url: str, element_name: str, user: "User", retries: int = SELENIUM_RETRIES + ) -> Optional[bytes]: + driver = self.auth(user) + driver.set_window_size(*self._window) + driver.get(url) + img: Optional[bytes] = None + logger.debug(f"Sleeping for {SELENIUM_HEADSTART} seconds") + time.sleep(SELENIUM_HEADSTART) + try: + logger.debug(f"Wait for the presence of {element_name}") + element = WebDriverWait(driver, 10).until( + EC.presence_of_element_located((By.CLASS_NAME, element_name)) + ) + logger.debug(f"Wait for .loading to be done") + WebDriverWait(driver, 60).until_not( + EC.presence_of_all_elements_located((By.CLASS_NAME, "loading")) + ) + logger.info("Taking a PNG screenshot") + img = element.screenshot_as_png + except TimeoutException: + logger.error("Selenium timed out") + except WebDriverException as ex: + logger.error(ex) + # Some webdrivers do not support screenshots for elements. + # In such cases, take a screenshot of the entire page. 
+ img = driver.screenshot() # pylint: disable=no-member + finally: + self.destroy(driver, retries) + return img + + +class BaseScreenshot: + driver_type = "chrome" + thumbnail_type: str = "" + element: str = "" + window_size: WindowSize = (800, 600) + thumb_size: WindowSize = (400, 300) + + def __init__(self, model_id: int): + self.model_id: int = model_id + self.screenshot: Optional[bytes] = None + self._driver = AuthWebDriverProxy(self.driver_type, self.window_size) + + @property + def cache_key(self) -> str: + return f"thumb__{self.thumbnail_type}__{self.model_id}" + + @property + def url(self) -> str: + raise NotImplementedError() + + def get_screenshot(self, user: "User") -> Optional[bytes]: + self.screenshot = self._driver.get_screenshot(self.url, self.element, user) + return self.screenshot + + def get( + self, + user: "User" = None, + cache: "Cache" = None, + thumb_size: Optional[WindowSize] = None, + ) -> Optional[BytesIO]: + """ + Get thumbnail screenshot has BytesIO from cache or fetch + + :param user: None to use current user or User Model to login and fetch + :param cache: The cache to use + :param thumb_size: Override thumbnail site + """ + payload: Optional[bytes] = None + thumb_size = thumb_size or self.thumb_size + if cache: + payload = cache.get(self.cache_key) + if not payload: + payload = self.compute_and_cache( + user=user, thumb_size=thumb_size, cache=cache + ) + else: + logger.info(f"Loaded thumbnail from cache: {self.cache_key}") + if payload: + return BytesIO(payload) + return None + + def get_from_cache(self, cache: "Cache") -> Optional[BytesIO]: + payload = cache.get(self.cache_key) + if payload: + return BytesIO(payload) + return None + + def compute_and_cache( # pylint: disable=too-many-arguments + self, + user: "User" = None, + thumb_size: Optional[WindowSize] = None, + cache: "Cache" = None, + force: bool = True, + ) -> Optional[bytes]: + """ + Fetches the screenshot, computes the thumbnail and caches the result + + :param user: If no 
user is given will use the current context + :param cache: The cache to keep the thumbnail payload + :param window_size: The window size from which the thumb is processed (read from the class attribute, not passed as an argument) + :param thumb_size: The final thumbnail size + :param force: Will force the computation even if it's already cached + :return: Image payload + """ + cache_key = self.cache_key + if not force and cache and cache.get(cache_key): + logger.info("Thumb already cached, skipping...") + return None + thumb_size = thumb_size or self.thumb_size + logger.info(f"Processing url for thumbnail: {cache_key}") + + payload = None + + # Assuming all sorts of things can go wrong with Selenium + try: + payload = self.get_screenshot(user=user) + except Exception as ex: # pylint: disable=broad-except + logger.error("Failed at generating thumbnail %s", ex) + + if payload and self.window_size != thumb_size: + try: + payload = self.resize_image(payload, thumb_size=thumb_size) + except Exception as ex: # pylint: disable=broad-except + logger.error("Failed at resizing thumbnail %s", ex) + payload = None + + if payload and cache: + logger.info(f"Caching thumbnail: {cache_key} {cache}") + cache.set(cache_key, payload) + return payload + + @classmethod + def resize_image( + cls, + img_bytes: bytes, + output: str = "png", + thumb_size: Optional[WindowSize] = None, + crop: bool = True, + ) -> bytes: + thumb_size = thumb_size or cls.thumb_size + img = Image.open(BytesIO(img_bytes)) + logger.debug(f"Selenium image size: {img.size}") + if crop and img.size[1] != cls.window_size[1]: + desired_ratio = float(cls.window_size[1]) / cls.window_size[0] + desired_width = int(img.size[0] * desired_ratio) + logger.debug(f"Cropping to: {img.size[0]}*{desired_width}") + img = img.crop((0, 0, img.size[0], desired_width)) + logger.debug(f"Resizing to {thumb_size}") + img = img.resize(thumb_size, Image.ANTIALIAS) + new_img = BytesIO() + if output != "png": + img = img.convert("RGB") + img.save(new_img, output) + new_img.seek(0) + return
new_img.read() + + +class ChartScreenshot(BaseScreenshot): + thumbnail_type: str = "chart" + element: str = "chart-container" + window_size: WindowSize = (600, int(600 * 0.75)) + thumb_size: WindowSize = (300, int(300 * 0.75)) + + @property + def url(self) -> str: + return get_url_path("Superset.slice", slice_id=self.model_id, standalone="true") + + +class DashboardScreenshot(BaseScreenshot): + thumbnail_type: str = "dashboard" + element: str = "grid-container" + window_size: WindowSize = (1600, int(1600 * 0.75)) + thumb_size: WindowSize = (400, int(400 * 0.75)) + + @property + def url(self) -> str: + return get_url_path("Superset.dashboard", dashboard_id=self.model_id) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 1023f7a98d..4afe13bf3d 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -22,6 +22,8 @@ from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.filters import BaseFilter, Filters from flask_appbuilder.models.sqla.filters import FilterStartsWith +from superset.stats_logger import BaseStatsLogger + logger = logging.getLogger(__name__) get_related_schema = { "type": "object", @@ -57,6 +59,7 @@ class BaseSupersetModelRestApi(ModelRestApi): "bulk_delete": "delete", "info": "list", "related": "list", + "thumbnail": "list", "refresh": "edit", "data": "list", } @@ -88,9 +91,9 @@ class BaseSupersetModelRestApi(ModelRestApi): """ # pylint: disable=pointless-string-statement allowed_rel_fields: Set[str] = set() - def __init__(self): + def __init__(self) -> None: super().__init__() - self.stats_logger = None + self.stats_logger = BaseStatsLogger() def create_blueprint(self, appbuilder, *args, **kwargs): self.stats_logger = self.appbuilder.get_app.config["STATS_LOGGER"] diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index ced4b3daca..03405b79e1 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -109,6 +109,7 @@ class 
DashboardApiTests(SupersetTestCase, ApiOwnersTestCaseMixin): "url": f"/superset/dashboard/slug1/", "slug": "slug1", "table_names": "", + "thumbnail_url": dashboard.thumbnail_url, } data = json.loads(rv.data.decode("utf-8")) self.assertIn("changed_on", data["result"]) diff --git a/tests/superset_test_config_thumbnails.py b/tests/superset_test_config_thumbnails.py new file mode 100644 index 0000000000..bf68df5e05 --- /dev/null +++ b/tests/superset_test_config_thumbnails.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# type: ignore +from copy import copy + +from flask import Flask +from werkzeug.contrib.cache import RedisCache + +from superset.config import * # type: ignore + +AUTH_USER_REGISTRATION_ROLE = "alpha" +SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db") +DEBUG = True +SUPERSET_WEBSERVER_PORT = 8081 + +# Allowing SQLALCHEMY_DATABASE_URI to be defined as an env var for +# continuous integration +if "SUPERSET__SQLALCHEMY_DATABASE_URI" in os.environ: + SQLALCHEMY_DATABASE_URI = os.environ["SUPERSET__SQLALCHEMY_DATABASE_URI"] + +SQL_SELECT_AS_CTA = True +SQL_MAX_ROW = 666 + + +def GET_FEATURE_FLAGS_FUNC(ff): + ff_copy = copy(ff) + ff_copy["super"] = "set" + return ff_copy + + +TESTING = True +WTF_CSRF_ENABLED = False +PUBLIC_ROLE_LIKE_GAMMA = True +AUTH_ROLE_PUBLIC = "Public" +EMAIL_NOTIFICATIONS = False + +CACHE_CONFIG = {"CACHE_TYPE": "simple"} + + +class CeleryConfig(object): + BROKER_URL = "redis://localhost" + CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks.thumbnails") + CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}} + CONCURRENCY = 1 + + +CELERY_CONFIG = CeleryConfig + +FEATURE_FLAGS = { + "foo": "bar", + "KV_STORE": False, + "SHARE_QUERIES_VIA_KV_STORE": False, + "THUMBNAILS": True, + "THUMBNAILS_SQLA_LISTENERS": False, +} + + +def init_thumbnail_cache(app: Flask) -> RedisCache: + return RedisCache( + host="localhost", key_prefix="superset_thumbnails_", default_timeout=10000 + ) + + +THUMBNAIL_CACHE_CONFIG = init_thumbnail_cache diff --git a/tests/thumbnails_tests.py b/tests/thumbnails_tests.py new file mode 100644 index 0000000000..ac8f16b270 --- /dev/null +++ b/tests/thumbnails_tests.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# from superset import db +# from superset.models.dashboard import Dashboard +import subprocess +import urllib.request +from unittest import skipUnless +from unittest.mock import patch + +from flask_testing import LiveServerTestCase +from sqlalchemy.sql import func + +from superset import db, is_feature_enabled, security_manager, thumbnail_cache +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.utils.screenshots import ( + ChartScreenshot, + DashboardScreenshot, + get_auth_cookies, +) +from tests.test_app import app + +from .base_tests import SupersetTestCase + + +class CeleryStartMixin: + @classmethod + def setUpClass(cls): + with app.app_context(): + from werkzeug.contrib.cache import RedisCache + + class CeleryConfig(object): + BROKER_URL = "redis://localhost" + CELERY_IMPORTS = ("superset.tasks.thumbnails",) + CONCURRENCY = 1 + + app.config["CELERY_CONFIG"] = CeleryConfig + + def init_thumbnail_cache(app) -> RedisCache: + return RedisCache( + host="localhost", + key_prefix="superset_thumbnails_", + default_timeout=10000, + ) + + app.config["THUMBNAIL_CACHE_CONFIG"] = init_thumbnail_cache + + base_dir = app.config["BASE_DIR"] + worker_command = base_dir + "/bin/superset worker -w 2" + subprocess.Popen( + worker_command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + @classmethod + def tearDownClass(cls): + 
subprocess.call( + "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9", shell=True + ) + subprocess.call( + "ps auxww | grep 'superset worker' | awk '{print $2}' | xargs kill -9", + shell=True, + ) + + +class ThumbnailsSeleniumLive(CeleryStartMixin, LiveServerTestCase): + def create_app(self): + return app + + def url_open_auth(self, username: str, url: str): + admin_user = security_manager.find_user(username=username) + cookies = {} + for cookie in get_auth_cookies(admin_user): + cookies["session"] = cookie + + opener = urllib.request.build_opener() + opener.addheaders.append(("Cookie", f"session={cookies['session']}")) + return opener.open(f"{self.get_server_url()}/{url}") + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_async_dashboard_screenshot(self): + """ + Thumbnails: Simple get async dashboard screenshot + """ + dashboard = db.session.query(Dashboard).all()[0] + with patch("superset.dashboards.api.DashboardRestApi.get") as mock_get: + response = self.url_open_auth( + "admin", + f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/", + ) + self.assertEqual(response.getcode(), 202) + + +class ThumbnailsTests(CeleryStartMixin, SupersetTestCase): + + mock_image = b"bytes mock image" + + def test_dashboard_thumbnail_disabled(self): + """ + Thumbnails: Dashboard thumbnail disabled + """ + if is_feature_enabled("THUMBNAILS"): + return + dashboard = db.session.query(Dashboard).all()[0] + self.login(username="admin") + uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_chart_thumbnail_disabled(self): + """ + Thumbnails: Chart thumbnail disabled + """ + if is_feature_enabled("THUMBNAILS"): + return + chart = db.session.query(Slice).all()[0] + self.login(username="admin") + uri = f"api/v1/chart/{chart}/thumbnail/{chart.digest}/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + 
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_async_dashboard_screenshot(self): + """ + Thumbnails: Simple get async dashboard screenshot + """ + dashboard = db.session.query(Dashboard).all()[0] + self.login(username="admin") + uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/" + with patch( + "superset.tasks.thumbnails.cache_dashboard_thumbnail.delay" + ) as mock_task: + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 202) + mock_task.assert_called_with(dashboard.id, force=True) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_async_dashboard_notfound(self): + """ + Thumbnails: Simple get async dashboard not found + """ + max_id = db.session.query(func.max(Dashboard.id)).scalar() + self.login(username="admin") + uri = f"api/v1/dashboard/{max_id + 1}/thumbnail/1234/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_async_dashboard_not_allowed(self): + """ + Thumbnails: Simple get async dashboard not allowed + """ + dashboard = db.session.query(Dashboard).all()[0] + self.login(username="gamma") + uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_async_chart_screenshot(self): + """ + Thumbnails: Simple get async chart screenshot + """ + chart = db.session.query(Slice).all()[0] + self.login(username="admin") + uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/" + with patch( + "superset.tasks.thumbnails.cache_chart_thumbnail.delay" + ) as mock_task: + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 202) + mock_task.assert_called_with(chart.id, force=True) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def 
test_get_async_chart_notfound(self): + """ + Thumbnails: Simple get async chart not found + """ + max_id = db.session.query(func.max(Slice.id)).scalar() + self.login(username="admin") + uri = f"api/v1/chart/{max_id + 1}/thumbnail/1234/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_cached_chart_wrong_digest(self): + """ + Thumbnails: Simple get chart with wrong digest + """ + chart = db.session.query(Slice).all()[0] + # Cache a test "image" + screenshot = ChartScreenshot(model_id=chart.id) + thumbnail_cache.set(screenshot.cache_key, self.mock_image) + self.login(username="admin") + uri = f"api/v1/chart/{chart.id}/thumbnail/1234/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 302) + self.assertRedirects(rv, f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/") + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_cached_dashboard_screenshot(self): + """ + Thumbnails: Simple get cached dashboard screenshot + """ + dashboard = db.session.query(Dashboard).all()[0] + # Cache a test "image" + screenshot = DashboardScreenshot(model_id=dashboard.id) + thumbnail_cache.set(screenshot.cache_key, self.mock_image) + self.login(username="admin") + uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + self.assertEqual(rv.data, self.mock_image) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_cached_chart_screenshot(self): + """ + Thumbnails: Simple get cached chart screenshot + """ + chart = db.session.query(Slice).all()[0] + # Cache a test "image" + screenshot = ChartScreenshot(model_id=chart.id) + thumbnail_cache.set(screenshot.cache_key, self.mock_image) + self.login(username="admin") + uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/" + rv = self.client.get(uri) + 
self.assertEqual(rv.status_code, 200) + self.assertEqual(rv.data, self.mock_image) + + @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") + def test_get_cached_dashboard_wrong_digest(self): + """ + Thumbnails: Simple get dashboard with wrong digest + """ + dashboard = db.session.query(Dashboard).all()[0] + # Cache a test "image" + screenshot = DashboardScreenshot(model_id=dashboard.id) + thumbnail_cache.set(screenshot.cache_key, self.mock_image) + self.login(username="admin") + uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/1234/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 302) + self.assertRedirects( + rv, f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/" + ) diff --git a/tox.ini b/tox.ini index 0a7dcfffaa..e37f488424 100644 --- a/tox.ini +++ b/tox.ini @@ -33,6 +33,14 @@ setenv = whitelist_externals = npm +[testenv:thumbnails] +setenv = + SUPERSET_CONFIG = tests.superset_test_config_thumbnails +deps = + -rrequirements.txt + -rrequirements-dev.txt + .[postgres] + [testenv:black] commands = black --check setup.py superset tests