fix(celery cache warmup): add auth and use warm_up_cache endpoint (#21076)

This commit is contained in:
ʈᵃᵢ 2022-08-30 09:24:24 -07:00 committed by GitHub
parent b354f2265a
commit 04dd8d414d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 181 deletions

View File

@@ -69,6 +69,16 @@ REDIS_RESULTS_DB = get_env_variable("REDIS_RESULTS_DB", "1")
# SQL Lab results backend: cached on the local filesystem under the app home dir.
RESULTS_BACKEND = FileSystemCache("/app/superset_home/sqllab")
# Metadata cache backed by Redis; host/port/db come from the REDIS_* variables above.
CACHE_CONFIG = {
    "CACHE_TYPE": "redis",
    "CACHE_DEFAULT_TIMEOUT": 300,  # default TTL for cached entries
    "CACHE_KEY_PREFIX": "superset_",
    "CACHE_REDIS_HOST": REDIS_HOST,
    "CACHE_REDIS_PORT": REDIS_PORT,
    "CACHE_REDIS_DB": REDIS_RESULTS_DB,
}
# Chart data cache reuses the same Redis configuration.
DATA_CACHE_CONFIG = CACHE_CONFIG
class CeleryConfig(object):
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"

View File

@@ -14,73 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib import request
from urllib.error import URLError
from celery.beat import SchedulingError
from celery.utils.log import get_task_logger
from sqlalchemy import and_, func
from superset import app, db
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject
from superset.utils.date_parser import parse_human_datetime
from superset.views.utils import build_extra_filters
from superset.utils.machine_auth import MachineAuthProvider
logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)
def get_form_data(
    chart_id: int, dashboard: Optional[Dashboard] = None
) -> Dict[str, Any]:
    """
    Build `form_data` for a chart GET request from a dashboard's
    `default_filters`.

    A dashboard's `default_filters` must be forwarded as `extra_filters`
    when requesting its charts; without a dashboard (or without metadata)
    only the slice id is sent.
    """
    form_data: Dict[str, Any] = {"slice_id": chart_id}

    # Nothing to merge in without dashboard metadata.
    if dashboard is None or not dashboard.json_metadata:
        return form_data

    metadata = json.loads(dashboard.json_metadata)
    default_filters = json.loads(metadata.get("default_filters", "null"))
    if not default_filters:
        return form_data

    scopes = metadata.get("filter_scopes", {})
    layout = json.loads(dashboard.position_json or "{}")

    # Only translate filters when all three structures are well-formed dicts.
    all_dicts = (
        isinstance(layout, dict)
        and isinstance(scopes, dict)
        and isinstance(default_filters, dict)
    )
    if all_dicts:
        extras = build_extra_filters(layout, scopes, default_filters, chart_id)
        if extras:
            form_data["extra_filters"] = extras

    return form_data
def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
def get_url(chart: Slice, dashboard: Optional[Dashboard] = None) -> str:
"""Return external URL for warming up a given chart/table cache."""
with app.test_request_context():
baseurl = (
"{SUPERSET_WEBSERVER_PROTOCOL}://"
"{SUPERSET_WEBSERVER_ADDRESS}:"
"{SUPERSET_WEBSERVER_PORT}".format(**app.config)
)
return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}"
baseurl = "{WEBDRIVER_BASEURL}".format(**app.config)
url = f"{baseurl}superset/warm_up_cache/?slice_id={chart.id}"
if dashboard:
url += f"&dashboard_id={dashboard.id}"
return url
class Strategy: # pylint: disable=too-few-public-methods
@@ -179,8 +142,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard)
urls.append(get_url(chart, form_data_with_filters))
urls.append(get_url(chart, dashboard))
return urls
@@ -253,6 +215,30 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
@celery_app.task(name="fetch_url")
def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]:
    """
    Celery job to fetch url

    Issues a GET request for ``url`` (a cache warm-up endpoint) with the
    supplied ``headers`` (e.g. the auth session cookie) and returns a
    small status dict so the outcome is visible in the task backend:
    ``{"success": url, "response": ...}`` on HTTP 200, otherwise an
    ``{"error": url, ...}`` entry.
    """
    result = {}
    try:
        logger.info("Fetching %s", url)
        req = request.Request(url, headers=headers)
        # Use a context manager so the HTTP response is always closed,
        # even if reading/decoding the body raises (the original left
        # the response open — a socket/file-descriptor leak).
        with request.urlopen(req, timeout=600) as response:
            logger.info("Fetched %s, status code: %s", url, response.code)
            if response.code == 200:
                result = {"success": url, "response": response.read().decode("utf-8")}
            else:
                result = {"error": url, "status_code": response.code}
                logger.error("Error fetching %s, status code: %s", url, response.code)
    except URLError as err:
        logger.exception("Error warming up cache!")
        result = {"error": url, "exception": str(err)}
    return result
@celery_app.task(name="cache-warmup")
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
@@ -282,14 +268,18 @@ def cache_warmup(
logger.exception(message)
return message
results: Dict[str, List[str]] = {"success": [], "errors": []}
user = security_manager.get_user_by_username(app.config["THUMBNAIL_SELENIUM_USER"])
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {"Cookie": f"session={cookies.get('session', '')}"}
results: Dict[str, List[str]] = {"scheduled": [], "errors": []}
for url in strategy.get_urls():
try:
logger.info("Fetching %s", url)
request.urlopen(url) # pylint: disable=consider-using-with
results["success"].append(url)
except URLError:
logger.exception("Error warming up cache!")
logger.info("Scheduling %s", url)
fetch_url.delay(url, headers)
results["scheduled"].append(url)
except SchedulingError:
logger.exception("Error scheduling fetch_url: %s", url)
results["errors"].append(url)
return results

View File

@@ -38,9 +38,9 @@ from superset.models.core import Log
from superset.models.tags import get_tag, ObjectTypes, TaggedObject, TagTypes
from superset.tasks.cache import (
DashboardTagsStrategy,
get_form_data,
TopNDashboardsStrategy,
)
from superset.utils.urls import get_url_host
from .base_tests import SupersetTestCase
from .dashboard_utils import create_dashboard, create_slice, create_table_metadata
@@ -49,7 +49,6 @@ from .fixtures.unicode_dashboard import (
load_unicode_data,
)
# Base URL of the test webserver, used to build the expected chart URLs.
URL_PREFIX = "http://0.0.0.0:8081"
mock_positions = {
"DASHBOARD_VERSION_KEY": "v2",
@@ -69,128 +68,6 @@ mock_positions = {
class TestCacheWarmUp(SupersetTestCase):
def test_get_form_data_chart_only(self):
    """A chart requested outside any dashboard yields only its slice id."""
    self.assertEqual(get_form_data(1, None), {"slice_id": 1})
def test_get_form_data_no_dashboard_metadata(self):
    """A dashboard without json_metadata contributes no extra filters."""
    dashboard = MagicMock()
    dashboard.json_metadata = None
    dashboard.position_json = json.dumps(mock_positions)
    self.assertEqual(get_form_data(1, dashboard), {"slice_id": 1})
def test_get_form_data_immune_slice(self):
    """A chart marked immune to the filter box gets no extra filters."""
    chart_id, filter_box_id = 1, 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    metadata = {
        "filter_scopes": {
            str(filter_box_id): {
                "name": {"scope": ["ROOT_ID"], "immune": [chart_id]}
            }
        },
        "default_filters": json.dumps(
            {str(filter_box_id): {"name": ["Alice", "Bob"]}}
        ),
    }
    dashboard.json_metadata = json.dumps(metadata)
    self.assertEqual(get_form_data(chart_id, dashboard), {"slice_id": chart_id})
def test_get_form_data_no_default_filters(self):
    """Empty dashboard metadata yields only the slice id."""
    dashboard = MagicMock()
    dashboard.json_metadata = json.dumps({})
    dashboard.position_json = json.dumps(mock_positions)
    self.assertEqual(get_form_data(1, dashboard), {"slice_id": 1})
def test_get_form_data_immune_fields(self):
    """Only non-immune fields of a filter box become extra filters."""
    chart_id, filter_box_id = 1, 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    defaults = {
        str(filter_box_id): {
            "name": ["Alice", "Bob"],
            "__time_range": "100 years ago : today",
        }
    }
    scopes = {
        str(filter_box_id): {
            "__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
        }
    }
    dashboard.json_metadata = json.dumps(
        {"default_filters": json.dumps(defaults), "filter_scopes": scopes}
    )
    expected = {
        "slice_id": chart_id,
        "extra_filters": [{"col": "name", "op": "in", "val": ["Alice", "Bob"]}],
    }
    self.assertEqual(get_form_data(chart_id, dashboard), expected)
def test_get_form_data_no_extra_filters(self):
    """When every default filter is immune, no extras are produced."""
    chart_id, filter_box_id = 1, 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    defaults = {str(filter_box_id): {"__time_range": "100 years ago : today"}}
    scopes = {
        str(filter_box_id): {
            "__time_range": {"scope": ["ROOT_ID"], "immune": [chart_id]}
        }
    }
    dashboard.json_metadata = json.dumps(
        {"default_filters": json.dumps(defaults), "filter_scopes": scopes}
    )
    self.assertEqual(get_form_data(chart_id, dashboard), {"slice_id": chart_id})
def test_get_form_data(self):
    """Default filters are translated into extra filters for the chart."""
    chart_id, filter_box_id = 1, 2
    dashboard = MagicMock()
    dashboard.position_json = json.dumps(mock_positions)
    defaults = {
        str(filter_box_id): {
            "name": ["Alice", "Bob"],
            "__time_range": "100 years ago : today",
        }
    }
    dashboard.json_metadata = json.dumps({"default_filters": json.dumps(defaults)})
    expected = {
        "slice_id": chart_id,
        "extra_filters": [
            {"col": "name", "op": "in", "val": ["Alice", "Bob"]},
            {"col": "__time_range", "op": "==", "val": "100 years ago : today"},
        ],
    }
    self.assertEqual(get_form_data(chart_id, dashboard), expected)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_top_n_dashboards_strategy(self):
# create a top visited dashboard
@@ -202,7 +79,12 @@ class TestCacheWarmUp(SupersetTestCase):
strategy = TopNDashboardsStrategy(1)
result = sorted(strategy.get_urls())
expected = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
expected = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}&dashboard_id={dash.id}"
for slc in dash.slices
]
)
self.assertEqual(result, expected)
def reset_tag(self, tag):
@@ -228,7 +110,12 @@ class TestCacheWarmUp(SupersetTestCase):
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom)
dash = self.get_dash_by_slug("births")
tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
tag1_urls = sorted(
[
f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"
for slc in dash.slices
]
)
tagged_object = TaggedObject(
tag_id=tag1.id, object_id=dash.id, object_type=ObjectTypes.dashboard
)
@@ -248,7 +135,7 @@ class TestCacheWarmUp(SupersetTestCase):
# tag first slice
dash = self.get_dash_by_slug("unicode-test")
slc = dash.slices[0]
tag2_urls = [f"{URL_PREFIX}{slc.url}"]
tag2_urls = [f"{get_url_host()}superset/warm_up_cache/?slice_id={slc.id}"]
object_id = slc.id
tagged_object = TaggedObject(
tag_id=tag2.id, object_id=object_id, object_type=ObjectTypes.chart

View File

@@ -73,6 +73,8 @@ FEATURE_FLAGS = {
"DRILL_TO_DETAIL": True,
}
# Base URL used to reach the webserver when building internal URLs
# (e.g. the cache warm-up endpoint); matches the test server address.
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
def GET_FEATURE_FLAGS_FUNC(ff):
ff_copy = copy(ff)