mirror of https://github.com/apache/superset.git
chore: Migrating reports to AuthWebdriverProxy (#10567)
* Migrating reports to AuthWebdriverProxy * Extracting out webdriver proxy / Adding thumbnail tests to CI * Adding license * Adding license again * Empty commit * Adding thumbnail tests to CI * Switching thumbnail test to Postgres * Linting * Adding mypy:ignore / removing thumbnail tests from CI * Putting ignore statement back * Updating docs * First cut at authprovider * First cut at authprovider mostly working - still needs more tests * Auth provider tests added * Linting * Linting again... * Linting again... * Busting CI cache * Reverting workflow change * Fixing dataclasses * Reverting back to master * linting? * Reverting installation.rst * Reverting package-lock.json * Addressing feedback * Blacking * Lazy logging strings * UPDATING.md note
This commit is contained in:
parent
8fb304d665
commit
2aaa4d92d9
|
@ -23,6 +23,8 @@ assists people when migrating to a new version.
|
|||
|
||||
## Next
|
||||
|
||||
* [10567](https://github.com/apache/incubator-superset/pull/10567): Default WEBDRIVER_OPTION_ARGS are Chrome-specific. If you're using FF, should be `--headless` only
|
||||
|
||||
* [10241](https://github.com/apache/incubator-superset/pull/10241): change on Alpha role, users started to have access to "Annotation Layers", "Css Templates" and "Import Dashboards".
|
||||
|
||||
* [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`.
|
||||
|
|
|
@ -26,8 +26,8 @@ function reset_db() {
|
|||
echo --------------------
|
||||
echo Reseting test DB
|
||||
echo --------------------
|
||||
docker-compose stop superset-tests-worker
|
||||
RESET_DB_CMD="psql \"postgresql://superset:superset@127.0.0.1:5432\" <<-EOF
|
||||
docker-compose stop superset-tests-worker superset || true
|
||||
RESET_DB_CMD="psql \"postgresql://${DB_USER}:${DB_PASSWORD}@127.0.0.1:5432\" <<-EOF
|
||||
DROP DATABASE IF EXISTS ${DB_NAME};
|
||||
CREATE DATABASE ${DB_NAME};
|
||||
\\c ${DB_NAME}
|
||||
|
@ -53,10 +53,6 @@ function test_init() {
|
|||
echo Superset init
|
||||
echo --------------------
|
||||
superset init
|
||||
echo --------------------
|
||||
echo Load examples
|
||||
echo --------------------
|
||||
pytest -s tests/load_examples_test.py
|
||||
}
|
||||
|
||||
#
|
||||
|
@ -142,5 +138,5 @@ fi
|
|||
|
||||
if [ $RUN_TESTS -eq 1 ]
|
||||
then
|
||||
pytest -x -s --ignore=load_examples_test "${TEST_MODULE}"
|
||||
pytest -x -s "${TEST_MODULE}"
|
||||
fi
|
||||
|
|
|
@ -36,6 +36,7 @@ from superset.extensions import (
|
|||
db,
|
||||
feature_flag_manager,
|
||||
jinja_context_manager,
|
||||
machine_auth_provider_factory,
|
||||
manifest_processor,
|
||||
migrate,
|
||||
results_backend_manager,
|
||||
|
@ -468,6 +469,7 @@ class SupersetAppInitializer:
|
|||
self.configure_fab()
|
||||
self.configure_url_map_converters()
|
||||
self.configure_data_sources()
|
||||
self.configure_auth_provider()
|
||||
|
||||
# Hook that provides administrators a handle on the Flask APP
|
||||
# after initialization
|
||||
|
@ -499,6 +501,9 @@ class SupersetAppInitializer:
|
|||
|
||||
self.post_init()
|
||||
|
||||
def configure_auth_provider(self) -> None:
|
||||
machine_auth_provider_factory.init_app(self.flask_app)
|
||||
|
||||
def setup_event_logger(self) -> None:
|
||||
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
|
||||
self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
|
||||
|
|
|
@ -761,6 +761,11 @@ SLACK_PROXY = None
|
|||
# * Emails are sent using dry-run mode (logging only)
|
||||
SCHEDULED_EMAIL_DEBUG_MODE = False
|
||||
|
||||
# This auth provider is used by background (offline) tasks that need to access
|
||||
# protected resources. Can be overridden by end users in order to support
|
||||
# custom auth mechanisms
|
||||
MACHINE_AUTH_PROVIDER_CLASS = "superset.utils.machine_auth.MachineAuthProvider"
|
||||
|
||||
# Email reports - minimum time resolution (in minutes) for the crontab
|
||||
EMAIL_REPORTS_CRON_RESOLUTION = 15
|
||||
|
||||
|
@ -795,9 +800,22 @@ EMAIL_REPORTS_WEBDRIVER = "firefox"
|
|||
# Window size - this will impact the rendering of the data
|
||||
WEBDRIVER_WINDOW = {"dashboard": (1600, 2000), "slice": (3000, 1200)}
|
||||
|
||||
# An optional override to the default auth hook used to provide auth to the
|
||||
# offline webdriver
|
||||
WEBDRIVER_AUTH_FUNC = None
|
||||
|
||||
# Any config options to be passed as-is to the webdriver
|
||||
WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {}
|
||||
|
||||
# Additional args to be passed as arguments to the config object
|
||||
# Note: these options are Chrome-specific. For FF, these should
|
||||
# only include the "--headless" arg
|
||||
WEBDRIVER_OPTION_ARGS = [
|
||||
"--force-device-scale-factor=2.0",
|
||||
"--high-dpi-support=2.0",
|
||||
"--headless",
|
||||
]
|
||||
|
||||
# The base URL to query for accessing the user interface
|
||||
WEBDRIVER_BASEURL = "http://0.0.0.0:8080/"
|
||||
# The base URL for the email report hyperlinks.
|
||||
|
|
|
@ -34,6 +34,7 @@ from werkzeug.local import LocalProxy
|
|||
|
||||
from superset.utils.cache_manager import CacheManager
|
||||
from superset.utils.feature_flag_manager import FeatureFlagManager
|
||||
from superset.utils.machine_auth import MachineAuthProviderFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.jinja_context import ( # pylint: disable=unused-import
|
||||
|
@ -139,6 +140,7 @@ _event_logger: Dict[str, Any] = {}
|
|||
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
|
||||
feature_flag_manager = FeatureFlagManager()
|
||||
jinja_context_manager = JinjaContextManager()
|
||||
machine_auth_provider_factory = MachineAuthProviderFactory()
|
||||
manifest_processor = UIManifestProcessor(APP_DIR)
|
||||
migrate = Migrate()
|
||||
results_backend_manager = ResultsBackendManager()
|
||||
|
|
|
@ -28,7 +28,6 @@ from typing import (
|
|||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
|
@ -42,17 +41,16 @@ import pandas as pd
|
|||
import simplejson as json
|
||||
from celery.app.task import Task
|
||||
from dateutil.tz import tzlocal
|
||||
from flask import current_app, render_template, Response, session, url_for
|
||||
from flask import current_app, render_template, url_for
|
||||
from flask_babel import gettext as __
|
||||
from flask_login import login_user
|
||||
from retry.api import retry_call
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from selenium.webdriver import chrome, firefox
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
|
||||
from werkzeug.http import parse_cookie
|
||||
|
||||
from superset import app, db, security_manager, thumbnail_cache
|
||||
from superset.extensions import celery_app
|
||||
from superset.extensions import celery_app, machine_auth_provider_factory
|
||||
from superset.models.alerts import Alert, AlertLog
|
||||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
@ -66,7 +64,7 @@ from superset.models.slice import Slice
|
|||
from superset.sql_parse import ParsedQuery
|
||||
from superset.tasks.slack_util import deliver_slack_msg
|
||||
from superset.utils.core import get_email_address_list, send_email_smtp
|
||||
from superset.utils.screenshots import ChartScreenshot
|
||||
from superset.utils.screenshots import ChartScreenshot, WebDriverProxy
|
||||
from superset.utils.urls import get_url_path
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
|
@ -74,6 +72,7 @@ from superset.utils.urls import get_url_path
|
|||
if TYPE_CHECKING:
|
||||
# pylint: disable=unused-import
|
||||
from werkzeug.datastructures import TypeConversionDict
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
|
||||
# Globals
|
||||
|
@ -191,27 +190,6 @@ def _generate_report_content(
|
|||
return ReportContent(body, data, images, slack_message, screenshot)
|
||||
|
||||
|
||||
def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
|
||||
# Login with the user specified to get the reports
|
||||
with app.test_request_context():
|
||||
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
||||
login_user(user)
|
||||
|
||||
# A mock response object to get the cookie information from
|
||||
response = Response()
|
||||
app.session_interface.save_session(app, session, response)
|
||||
|
||||
cookies = []
|
||||
|
||||
# Set the cookies in the driver
|
||||
for name, value in response.headers:
|
||||
if name.lower() == "set-cookie":
|
||||
cookie = parse_cookie(value)
|
||||
cookies.append(cookie["session"])
|
||||
|
||||
return cookies
|
||||
|
||||
|
||||
def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:
|
||||
with app.test_request_context():
|
||||
base_url = (
|
||||
|
@ -220,44 +198,14 @@ def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:
|
|||
return urllib.parse.urljoin(str(base_url), url_for(view, **kwargs))
|
||||
|
||||
|
||||
def create_webdriver() -> Union[
|
||||
chrome.webdriver.WebDriver, firefox.webdriver.WebDriver
|
||||
]:
|
||||
# Create a webdriver for use in fetching reports
|
||||
if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox":
|
||||
driver_class = firefox.webdriver.WebDriver
|
||||
options = firefox.options.Options()
|
||||
elif config["EMAIL_REPORTS_WEBDRIVER"] == "chrome":
|
||||
driver_class = chrome.webdriver.WebDriver
|
||||
options = chrome.options.Options()
|
||||
def create_webdriver() -> WebDriver:
|
||||
return WebDriverProxy(driver_type=config["EMAIL_REPORTS_WEBDRIVER"]).auth(
|
||||
get_reports_user()
|
||||
)
|
||||
|
||||
options.add_argument("--headless")
|
||||
|
||||
# Prepare args for the webdriver init
|
||||
kwargs = dict(options=options)
|
||||
kwargs.update(config["WEBDRIVER_CONFIGURATION"])
|
||||
|
||||
# Initialize the driver
|
||||
driver = driver_class(**kwargs)
|
||||
|
||||
# Some webdrivers need an initial hit to the welcome URL
|
||||
# before we set the cookie
|
||||
welcome_url = _get_url_path("Superset.welcome")
|
||||
|
||||
# Hit the welcome URL and check if we were asked to login
|
||||
driver.get(welcome_url)
|
||||
elements = driver.find_elements_by_id("loginbox")
|
||||
|
||||
# This indicates that we were not prompted for a login box.
|
||||
if not elements:
|
||||
return driver
|
||||
|
||||
# Set the cookies in the driver
|
||||
for cookie in _get_auth_cookies():
|
||||
info = dict(name="session", value=cookie)
|
||||
driver.add_cookie(info)
|
||||
|
||||
return driver
|
||||
def get_reports_user() -> "User":
|
||||
return security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
||||
|
||||
|
||||
def destroy_webdriver(
|
||||
|
@ -364,12 +312,15 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte
|
|||
"Superset.slice", slice_id=slc.id, user_friendly=True
|
||||
)
|
||||
|
||||
cookies = {}
|
||||
for cookie in _get_auth_cookies():
|
||||
cookies["session"] = cookie
|
||||
# Login on behalf of the "reports" user in order to get cookies to deal with auth
|
||||
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
|
||||
get_reports_user()
|
||||
)
|
||||
# Build something like "session=cool_sess.val;other-cookie=awesome_other_cookie"
|
||||
cookie_str = ";".join([f"{key}={val}" for key, val in auth_cookies.items()])
|
||||
|
||||
opener = urllib.request.build_opener()
|
||||
opener.addheaders.append(("Cookie", f"session={cookies['session']}"))
|
||||
opener.addheaders.append(("Cookie", cookie_str))
|
||||
response = opener.open(slice_url)
|
||||
if response.getcode() != 200:
|
||||
raise URLError(response.getcode())
|
||||
|
|
|
@ -18,18 +18,17 @@
|
|||
"""Utility functions used across Superset"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from superset import app, security_manager, thumbnail_cache
|
||||
from superset.extensions import celery_app
|
||||
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
|
||||
from superset.utils.webdriver import WindowSize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WindowSize = Tuple[int, int]
|
||||
|
||||
|
||||
@celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300)
|
||||
def cache_chart_thumbnail(
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Callable, Dict, TYPE_CHECKING
|
||||
|
||||
from flask import current_app, Flask, request, Response, session
|
||||
from flask_login import login_user
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
from werkzeug.http import parse_cookie
|
||||
|
||||
from superset.utils.urls import headless_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# pylint: disable=unused-import
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
|
||||
class MachineAuthProvider:
|
||||
def __init__(
|
||||
self, auth_webdriver_func_override: Callable[[WebDriver, "User"], WebDriver]
|
||||
):
|
||||
# This is here in order to allow for the authenticate_webdriver func to be
|
||||
# overridden via config, as opposed to the entire provider implementation
|
||||
self._auth_webdriver_func_override = auth_webdriver_func_override
|
||||
|
||||
def authenticate_webdriver(self, driver: WebDriver, user: "User",) -> WebDriver:
|
||||
"""
|
||||
Default AuthDriverFuncType type that sets a session cookie flask-login style
|
||||
:return: The WebDriver passed in (fluent)
|
||||
"""
|
||||
# Short-circuit this method if we have an override configured
|
||||
if self._auth_webdriver_func_override:
|
||||
return self._auth_webdriver_func_override(driver, user)
|
||||
|
||||
# Setting cookies requires doing a request first
|
||||
driver.get(headless_url("/login/"))
|
||||
|
||||
if user:
|
||||
cookies = self.get_auth_cookies(user)
|
||||
elif request.cookies:
|
||||
cookies = request.cookies
|
||||
else:
|
||||
cookies = {}
|
||||
|
||||
for cookie_name, cookie_val in cookies.items():
|
||||
driver.add_cookie(dict(name=cookie_name, value=cookie_val))
|
||||
|
||||
return driver
|
||||
|
||||
@staticmethod
|
||||
def get_auth_cookies(user: "User") -> Dict[str, str]:
|
||||
# Login with the user specified to get the reports
|
||||
with current_app.test_request_context("/login"):
|
||||
login_user(user)
|
||||
# A mock response object to get the cookie information from
|
||||
response = Response()
|
||||
current_app.session_interface.save_session(current_app, session, response)
|
||||
|
||||
cookies = {}
|
||||
|
||||
# Grab any "set-cookie" headers from the login response
|
||||
for name, value in response.headers:
|
||||
if name.lower() == "set-cookie":
|
||||
# This yields a MultiDict, which is ordered -- something like
|
||||
# MultiDict([('session', 'value-we-want), ('HttpOnly', ''), etc...
|
||||
# Therefore, we just need to grab the first tuple and add it to our
|
||||
# final dict
|
||||
cookie = parse_cookie(value)
|
||||
cookie_tuple = list(cookie.items())[0]
|
||||
cookies[cookie_tuple[0]] = cookie_tuple[1]
|
||||
|
||||
return cookies
|
||||
|
||||
|
||||
class MachineAuthProviderFactory:
|
||||
def __init__(self) -> None:
|
||||
self._auth_provider = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
auth_provider_fqclass = app.config["MACHINE_AUTH_PROVIDER_CLASS"]
|
||||
auth_provider_classname = auth_provider_fqclass[
|
||||
auth_provider_fqclass.rfind(".") + 1 :
|
||||
]
|
||||
auth_provider_module_name = auth_provider_fqclass[
|
||||
0 : auth_provider_fqclass.rfind(".")
|
||||
]
|
||||
auth_provider_class = getattr(
|
||||
importlib.import_module(auth_provider_module_name), auth_provider_classname
|
||||
)
|
||||
|
||||
self._auth_provider = auth_provider_class(app.config["WEBDRIVER_AUTH_FUNC"])
|
||||
|
||||
@property
|
||||
def instance(self) -> MachineAuthProvider:
|
||||
return self._auth_provider # type: ignore
|
|
@ -15,23 +15,13 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
from flask import current_app, request, Response, session
|
||||
from flask_login import login_user
|
||||
from retry.api import retry_call
|
||||
from selenium.common.exceptions import TimeoutException, WebDriverException
|
||||
from selenium.webdriver import chrome, firefox
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from werkzeug.http import parse_cookie
|
||||
from flask import current_app
|
||||
|
||||
from superset.utils.hashing import md5_sha_from_dict
|
||||
from superset.utils.urls import headless_url
|
||||
from superset.utils.webdriver import WebDriverProxy, WindowSize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,140 +35,6 @@ if TYPE_CHECKING:
|
|||
from flask_appbuilder.security.sqla.models import User
|
||||
from flask_caching import Cache
|
||||
|
||||
# Time in seconds, we will wait for the page to load and render
|
||||
SELENIUM_CHECK_INTERVAL = 2
|
||||
SELENIUM_RETRIES = 5
|
||||
SELENIUM_HEADSTART = 3
|
||||
|
||||
WindowSize = Tuple[int, int]
|
||||
|
||||
|
||||
def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
|
||||
# Login with the user specified to get the reports
|
||||
with current_app.test_request_context("/login"):
|
||||
login_user(user)
|
||||
# A mock response object to get the cookie information from
|
||||
response = Response()
|
||||
current_app.session_interface.save_session(current_app, session, response)
|
||||
|
||||
cookies = []
|
||||
|
||||
# Set the cookies in the driver
|
||||
for name, value in response.headers:
|
||||
if name.lower() == "set-cookie":
|
||||
cookie = parse_cookie(value)
|
||||
cookies.append(cookie["session"])
|
||||
return cookies
|
||||
|
||||
|
||||
def auth_driver(driver: WebDriver, user: "User") -> WebDriver:
|
||||
"""
|
||||
Default AuthDriverFuncType type that sets a session cookie flask-login style
|
||||
:return: WebDriver
|
||||
"""
|
||||
if user:
|
||||
# Set the cookies in the driver
|
||||
for cookie in get_auth_cookies(user):
|
||||
info = dict(name="session", value=cookie)
|
||||
driver.add_cookie(info)
|
||||
elif request.cookies:
|
||||
cookies = request.cookies
|
||||
for k, v in cookies.items():
|
||||
cookie = dict(name=k, value=v)
|
||||
driver.add_cookie(cookie)
|
||||
return driver
|
||||
|
||||
|
||||
class AuthWebDriverProxy:
|
||||
def __init__(
|
||||
self,
|
||||
driver_type: str,
|
||||
window: Optional[WindowSize] = None,
|
||||
auth_func: Optional[
|
||||
Callable[..., Any]
|
||||
] = None, # pylint: disable=bad-whitespace
|
||||
):
|
||||
self._driver_type = driver_type
|
||||
self._window: WindowSize = window or (800, 600)
|
||||
config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
|
||||
self._auth_func = auth_func or config_auth_func
|
||||
|
||||
def create(self) -> WebDriver:
|
||||
if self._driver_type == "firefox":
|
||||
driver_class = firefox.webdriver.WebDriver
|
||||
options = firefox.options.Options()
|
||||
elif self._driver_type == "chrome":
|
||||
driver_class = chrome.webdriver.WebDriver
|
||||
options = chrome.options.Options()
|
||||
arg: str = f"--window-size={self._window[0]},{self._window[1]}"
|
||||
options.add_argument(arg)
|
||||
# TODO: 2 lines attempting retina PPI don't seem to be working
|
||||
options.add_argument("--force-device-scale-factor=2.0")
|
||||
options.add_argument("--high-dpi-support=2.0")
|
||||
else:
|
||||
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
|
||||
# Prepare args for the webdriver init
|
||||
options.add_argument("--headless")
|
||||
kwargs: Dict[Any, Any] = dict(options=options)
|
||||
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
|
||||
logger.info("Init selenium driver")
|
||||
return driver_class(**kwargs)
|
||||
|
||||
def auth(self, user: "User") -> WebDriver:
|
||||
# Setting cookies requires doing a request first
|
||||
driver = self.create()
|
||||
driver.get(headless_url("/login/"))
|
||||
return self._auth_func(driver, user)
|
||||
|
||||
@staticmethod
|
||||
def destroy(driver: WebDriver, tries: int = 2) -> None:
|
||||
"""Destroy a driver"""
|
||||
# This is some very flaky code in selenium. Hence the retries
|
||||
# and catch-all exceptions
|
||||
try:
|
||||
retry_call(driver.close, tries=tries)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
try:
|
||||
driver.quit()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
def get_screenshot(
|
||||
self,
|
||||
url: str,
|
||||
element_name: str,
|
||||
user: "User",
|
||||
retries: int = SELENIUM_RETRIES,
|
||||
) -> Optional[bytes]:
|
||||
driver = self.auth(user)
|
||||
driver.set_window_size(*self._window)
|
||||
driver.get(url)
|
||||
img: Optional[bytes] = None
|
||||
logger.debug("Sleeping for %i seconds", SELENIUM_HEADSTART)
|
||||
time.sleep(SELENIUM_HEADSTART)
|
||||
try:
|
||||
logger.debug("Wait for the presence of %s", element_name)
|
||||
element = WebDriverWait(
|
||||
driver, current_app.config["SCREENSHOT_LOCATE_WAIT"]
|
||||
).until(EC.presence_of_element_located((By.CLASS_NAME, element_name)))
|
||||
logger.debug("Wait for .loading to be done")
|
||||
WebDriverWait(driver, current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not(
|
||||
EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
|
||||
)
|
||||
logger.info("Taking a PNG screenshot")
|
||||
img = element.screenshot_as_png
|
||||
except TimeoutException:
|
||||
logger.error("Selenium timed out")
|
||||
except WebDriverException as ex:
|
||||
logger.error(ex)
|
||||
# Some webdrivers do not support screenshots for elements.
|
||||
# In such cases, take a screenshot of the entire page.
|
||||
img = driver.screenshot() # pylint: disable=no-member
|
||||
finally:
|
||||
self.destroy(driver, retries)
|
||||
return img
|
||||
|
||||
|
||||
class BaseScreenshot:
|
||||
driver_type = current_app.config.get("EMAIL_REPORTS_WEBDRIVER", "chrome")
|
||||
|
@ -192,9 +48,9 @@ class BaseScreenshot:
|
|||
self.url = url
|
||||
self.screenshot: Optional[bytes] = None
|
||||
|
||||
def driver(self, window_size: Optional[WindowSize] = None) -> AuthWebDriverProxy:
|
||||
def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy:
|
||||
window_size = window_size or self.window_size
|
||||
return AuthWebDriverProxy(self.driver_type, window_size)
|
||||
return WebDriverProxy(self.driver_type, window_size)
|
||||
|
||||
def cache_key(
|
||||
self,
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from flask import current_app
|
||||
from retry.api import retry_call
|
||||
from selenium.common.exceptions import TimeoutException, WebDriverException
|
||||
from selenium.webdriver import chrome, firefox
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
|
||||
from superset.extensions import machine_auth_provider_factory
|
||||
|
||||
WindowSize = Tuple[int, int]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Time in seconds, we will wait for the page to load and render
|
||||
SELENIUM_CHECK_INTERVAL = 2
|
||||
SELENIUM_RETRIES = 5
|
||||
SELENIUM_HEADSTART = 3
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# pylint: disable=unused-import
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
|
||||
class WebDriverProxy:
|
||||
def __init__(
|
||||
self, driver_type: str, window: Optional[WindowSize] = None,
|
||||
):
|
||||
self._driver_type = driver_type
|
||||
self._window: WindowSize = window or (800, 600)
|
||||
self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"]
|
||||
self._screenshot_load_wait = current_app.config["SCREENSHOT_LOAD_WAIT"]
|
||||
|
||||
def create(self) -> WebDriver:
|
||||
if self._driver_type == "firefox":
|
||||
driver_class = firefox.webdriver.WebDriver
|
||||
options = firefox.options.Options()
|
||||
elif self._driver_type == "chrome":
|
||||
driver_class = chrome.webdriver.WebDriver
|
||||
options = chrome.options.Options()
|
||||
options.add_argument(f"--window-size={self._window[0]},{self._window[1]}")
|
||||
else:
|
||||
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
|
||||
# Prepare args for the webdriver init
|
||||
|
||||
# Add additional configured options
|
||||
for arg in current_app.config["WEBDRIVER_OPTION_ARGS"]:
|
||||
options.add_argument(arg)
|
||||
|
||||
kwargs: Dict[Any, Any] = dict(options=options)
|
||||
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
|
||||
logger.info("Init selenium driver")
|
||||
|
||||
return driver_class(**kwargs)
|
||||
|
||||
def auth(self, user: "User") -> WebDriver:
|
||||
driver = self.create()
|
||||
return machine_auth_provider_factory.instance.authenticate_webdriver(
|
||||
driver, user
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def destroy(driver: WebDriver, tries: int = 2) -> None:
|
||||
"""Destroy a driver"""
|
||||
# This is some very flaky code in selenium. Hence the retries
|
||||
# and catch-all exceptions
|
||||
try:
|
||||
retry_call(driver.close, tries=tries)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
try:
|
||||
driver.quit()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
def get_screenshot(
|
||||
self,
|
||||
url: str,
|
||||
element_name: str,
|
||||
user: "User",
|
||||
retries: int = SELENIUM_RETRIES,
|
||||
) -> Optional[bytes]:
|
||||
driver = self.auth(user)
|
||||
driver.set_window_size(*self._window)
|
||||
driver.get(url)
|
||||
img: Optional[bytes] = None
|
||||
logger.debug("Sleeping for %i seconds", SELENIUM_HEADSTART)
|
||||
time.sleep(SELENIUM_HEADSTART)
|
||||
try:
|
||||
logger.debug("Wait for the presence of %s", element_name)
|
||||
element = WebDriverWait(driver, self._screenshot_locate_wait).until(
|
||||
EC.presence_of_element_located((By.CLASS_NAME, element_name))
|
||||
)
|
||||
logger.debug("Wait for .loading to be done")
|
||||
WebDriverWait(driver, self._screenshot_load_wait).until_not(
|
||||
EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
|
||||
)
|
||||
logger.info("Taking a PNG screenshot or url %s", url)
|
||||
img = element.screenshot_as_png
|
||||
except TimeoutException:
|
||||
logger.error("Selenium timed out requesting url %s", url)
|
||||
except WebDriverException as ex:
|
||||
logger.error(ex)
|
||||
# Some webdrivers do not support screenshots for elements.
|
||||
# In such cases, take a screenshot of the entire page.
|
||||
img = driver.screenshot() # pylint: disable=no-member
|
||||
finally:
|
||||
self.destroy(driver, retries)
|
||||
return img
|
|
@ -100,6 +100,7 @@ class SupersetTestCase(TestCase):
|
|||
assert user_to_create
|
||||
user_to_create.roles = [security_manager.find_role(r) for r in roles]
|
||||
db.session.commit()
|
||||
return user_to_create
|
||||
|
||||
@staticmethod
|
||||
def create_user(
|
||||
|
|
|
@ -40,8 +40,7 @@ from superset.tasks.schedules import (
|
|||
)
|
||||
from superset.models.slice import Slice
|
||||
from tests.base_tests import SupersetTestCase
|
||||
|
||||
from .utils import read_fixture
|
||||
from tests.utils import read_fixture
|
||||
|
||||
|
||||
class TestSchedules(SupersetTestCase):
|
||||
|
@ -172,7 +171,6 @@ class TestSchedules(SupersetTestCase):
|
|||
mock_driver_class.return_value = mock_driver
|
||||
mock_driver.find_elements_by_id.side_effect = [True, False]
|
||||
|
||||
create_webdriver()
|
||||
create_webdriver()
|
||||
mock_driver.add_cookie.assert_called_once()
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
# under the License.
|
||||
# from superset import db
|
||||
# from superset.models.dashboard import Dashboard
|
||||
import subprocess
|
||||
import urllib.request
|
||||
from unittest import skipUnless
|
||||
from unittest.mock import patch
|
||||
|
@ -24,15 +23,11 @@ from unittest.mock import patch
|
|||
from flask_testing import LiveServerTestCase
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
import tests.test_app
|
||||
from superset import db, is_feature_enabled, security_manager, thumbnail_cache
|
||||
from superset.extensions import machine_auth_provider_factory
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.screenshots import (
|
||||
ChartScreenshot,
|
||||
DashboardScreenshot,
|
||||
get_auth_cookies,
|
||||
)
|
||||
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
|
||||
from superset.utils.urls import get_url_path
|
||||
from tests.test_app import app
|
||||
|
||||
|
@ -45,10 +40,7 @@ class TestThumbnailsSeleniumLive(LiveServerTestCase):
|
|||
|
||||
def url_open_auth(self, username: str, url: str):
|
||||
admin_user = security_manager.find_user(username=username)
|
||||
cookies = {}
|
||||
for cookie in get_auth_cookies(admin_user):
|
||||
cookies["session"] = cookie
|
||||
|
||||
cookies = machine_auth_provider_factory.instance.get_auth_cookies(admin_user)
|
||||
opener = urllib.request.build_opener()
|
||||
opener.addheaders.append(("Cookie", f"session={cookies['session']}"))
|
||||
return opener.open(f"{self.get_server_url()}/{url}")
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,56 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import call, Mock, patch
|
||||
|
||||
from superset.extensions import machine_auth_provider_factory
|
||||
from tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class MachineAuthProviderTests(SupersetTestCase):
|
||||
def test_get_auth_cookies(self):
|
||||
user = self.get_user("admin")
|
||||
auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user)
|
||||
self.assertIsNotNone(auth_cookies["session"])
|
||||
|
||||
@patch("superset.utils.machine_auth.MachineAuthProvider.get_auth_cookies")
|
||||
def test_auth_driver_user(self, get_auth_cookies):
|
||||
user = self.get_user("admin")
|
||||
driver = Mock()
|
||||
get_auth_cookies.return_value = {
|
||||
"session": "session_val",
|
||||
"other_cookie": "other_val",
|
||||
}
|
||||
machine_auth_provider_factory.instance.authenticate_webdriver(driver, user)
|
||||
driver.add_cookie.assert_has_calls(
|
||||
[
|
||||
call({"name": "session", "value": "session_val"}),
|
||||
call({"name": "other_cookie", "value": "other_val"}),
|
||||
]
|
||||
)
|
||||
|
||||
@patch("superset.utils.machine_auth.request")
|
||||
def test_auth_driver_request(self, request):
|
||||
driver = Mock()
|
||||
request.cookies = {"session": "session_val", "other_cookie": "other_val"}
|
||||
machine_auth_provider_factory.instance.authenticate_webdriver(driver, None)
|
||||
driver.add_cookie.assert_has_calls(
|
||||
[
|
||||
call({"name": "session", "value": "session_val"}),
|
||||
call({"name": "other_cookie", "value": "other_val"}),
|
||||
]
|
||||
)
|
Loading…
Reference in New Issue