mirror of https://github.com/apache/superset.git
First cut at app factory

* Setting things back to master
* Working with new FLASK_APP
* Still need to refactor Celery
* CLI mostly working
* Working on unit tests
* Moving cli stuff around a bit
* Removing get in config
* Defaulting test config
* Adding flask-testing
* flask-testing casing
* resultsbackend property bug
* Fixing up cli
* Quick fix for KV api
* Working on save slice
* Fixed core_tests
* Fixed utils_tests
* Most tests working - still need to dig into remaining app_context issue in tests
* All tests passing locally - need to update code comments
* Fixing dashboard tests again
* Blacking
* Sorting imports
* linting
* removing envvar mangling
* blacking
* Fixing unit tests
* isorting
* licensing
* fixing mysql tests
* fixing cypress?
* fixing .flaskenv
* fixing test app_ctx
* fixing cypress
* moving manifest processor around
* moving results backend manager around
* Cleaning up __init__ a bit more
* Addressing PR comments
* Addressing PR comments
* Blacking
* Fixes for running celery worker
* Tuning isort
* Blacking
This commit is contained in:
parent 300c4ecb0f
commit e490414484
@@ -14,5 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
FLASK_APP=superset:app
FLASK_ENV=development
FLASK_APP="superset.app:create_app()"
FLASK_ENV="development"
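Note: pointing FLASK_APP at "superset.app:create_app()" makes the Flask CLI call the factory on demand instead of importing a prebuilt module-level app. A minimal sketch of the equivalent manual bootstrap (only create_app() comes from this diff; the __main__ guard is illustrative):

    # Hedged sketch: roughly what `flask run` does with the new setting.
    from superset.app import create_app

    app = create_app()  # builds and fully initializes a fresh Flask app

    if __name__ == "__main__":
        app.run(debug=True)  # dev server only; use a WSGI server in production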
@@ -17,6 +17,7 @@
black==19.3b0
coverage==4.5.3
flask-cors==3.0.7
flask-testing==0.7.1
ipdb==0.12
isort==4.3.21
mypy==0.670
setup.py
@@ -128,4 +128,5 @@ setup(
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
    ],
    tests_require=["flask-testing==0.7.1"],
)
@@ -14,229 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
"""Package's main module!"""
import json
import logging
import os
from copy import deepcopy
from typing import Any, Dict

from flask import current_app, Flask
from werkzeug.local import LocalProxy

import wtforms_json
from flask import Flask, redirect
from flask_appbuilder import AppBuilder, IndexView, SQLA
from flask_appbuilder.baseviews import expose
from flask_compress import Compress
from flask_migrate import Migrate
from flask_talisman import Talisman
from flask_wtf.csrf import CSRFProtect

from superset import config
from superset.app import create_app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import (
    appbuilder,
    cache_manager,
    db,
    event_logger,
    feature_flag_manager,
    manifest_processor,
    results_backend_manager,
    security_manager,
    talisman,
)
from superset.security import SupersetSecurityManager
from superset.utils.core import pessimistic_connection_handling, setup_cache
from superset.utils.log import get_event_logger_from_cfg_value
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value

wtforms_json.init()

APP_DIR = os.path.dirname(__file__)
CONFIG_MODULE = os.environ.get("SUPERSET_CONFIG", "superset.config")

if not os.path.exists(config.DATA_DIR):
    os.makedirs(config.DATA_DIR)

app = Flask(__name__)
app.config.from_object(CONFIG_MODULE)  # type: ignore
conf = app.config

#################################################################
# Handling manifest file logic at app start
#################################################################
MANIFEST_FILE = APP_DIR + "/static/assets/dist/manifest.json"
manifest: Dict[Any, Any] = {}


def parse_manifest_json():
    global manifest
    try:
        with open(MANIFEST_FILE, "r") as f:
            # the manifest inclues non-entry files
            # we only need entries in templates
            full_manifest = json.load(f)
            manifest = full_manifest.get("entrypoints", {})
    except Exception:
        pass


def get_js_manifest_files(filename):
    if app.debug:
        parse_manifest_json()
    entry_files = manifest.get(filename, {})
    return entry_files.get("js", [])


def get_css_manifest_files(filename):
    if app.debug:
        parse_manifest_json()
    entry_files = manifest.get(filename, {})
    return entry_files.get("css", [])


def get_unloaded_chunks(files, loaded_chunks):
    filtered_files = [f for f in files if f not in loaded_chunks]
    for f in filtered_files:
        loaded_chunks.add(f)
    return filtered_files


parse_manifest_json()


@app.context_processor
def get_manifest():
    return dict(
        loaded_chunks=set(),
        get_unloaded_chunks=get_unloaded_chunks,
        js_manifest=get_js_manifest_files,
        css_manifest=get_css_manifest_files,
    )


#################################################################

for bp in conf["BLUEPRINTS"]:
    try:
        print("Registering blueprint: '{}'".format(bp.name))
        app.register_blueprint(bp)
    except Exception as e:
        print("blueprint registration failed")
        logging.exception(e)

if conf.get("SILENCE_FAB"):
    logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)

db = SQLA(app)

if conf.get("WTF_CSRF_ENABLED"):
    csrf = CSRFProtect(app)
    csrf_exempt_list = conf.get("WTF_CSRF_EXEMPT_LIST", [])
    for ex in csrf_exempt_list:
        csrf.exempt(ex)

pessimistic_connection_handling(db.engine)

cache = setup_cache(app, conf.get("CACHE_CONFIG"))
tables_cache = setup_cache(app, conf.get("TABLE_NAMES_CACHE_CONFIG"))

migrate = Migrate(app, db, directory=APP_DIR + "/migrations")

app.config["LOGGING_CONFIGURATOR"].configure_logging(app.config, app.debug)

if app.config["ENABLE_CORS"]:
    from flask_cors import CORS

    CORS(app, **app.config["CORS_OPTIONS"])

if app.config["ENABLE_PROXY_FIX"]:
    from werkzeug.middleware.proxy_fix import ProxyFix

    app.wsgi_app = ProxyFix(  # type: ignore
        app.wsgi_app, **app.config["PROXY_FIX_CONFIG"]
    )

if app.config["ENABLE_CHUNK_ENCODING"]:

    class ChunkedEncodingFix(object):
        def __init__(self, app):
            self.app = app

        def __call__(self, environ, start_response):
            # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
            # content-length and read the stream till the end.
            if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == u"chunked":
                environ["wsgi.input_terminated"] = True
            return self.app(environ, start_response)

    app.wsgi_app = ChunkedEncodingFix(app.wsgi_app)  # type: ignore

if app.config["UPLOAD_FOLDER"]:
    try:
        os.makedirs(app.config["UPLOAD_FOLDER"])
    except OSError:
        pass

for middleware in app.config["ADDITIONAL_MIDDLEWARE"]:
    app.wsgi_app = middleware(app.wsgi_app)  # type: ignore


class MyIndexView(IndexView):
    @expose("/")
    def index(self):
        return redirect("/superset/welcome")


custom_sm = app.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager
if not issubclass(custom_sm, SupersetSecurityManager):
    raise Exception(
        """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
not FAB's security manager.
See [4565] in UPDATING.md"""
    )

with app.app_context():
    appbuilder = AppBuilder(
        app,
        db.session,
        base_template="superset/base.html",
        indexview=MyIndexView,
        security_manager_class=custom_sm,
        update_perms=False,  # Run `superset init` to update FAB's perms
    )

security_manager = appbuilder.sm

results_backend = app.config["RESULTS_BACKEND"]
results_backend_use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]

# Merge user defined feature flags with default feature flags
_feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
_feature_flags.update(app.config["FEATURE_FLAGS"])

# Event Logger
event_logger = get_event_logger_from_cfg_value(app.config["EVENT_LOGGER"])


def get_feature_flags():
    GET_FEATURE_FLAGS_FUNC = app.config["GET_FEATURE_FLAGS_FUNC"]
    if GET_FEATURE_FLAGS_FUNC:
        return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
    return _feature_flags


def is_feature_enabled(feature):
    """Utility function for checking whether a feature is turned on"""
    return get_feature_flags().get(feature)


# Flask-Compress
if conf.get("ENABLE_FLASK_COMPRESS"):
    Compress(app)


talisman = Talisman()

if app.config["TALISMAN_ENABLED"]:
    talisman.init_app(app, **app.config["TALISMAN_CONFIG"])

# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = app.config["FLASK_APP_MUTATOR"]
if flask_app_mutator:
    flask_app_mutator(app)

from superset import views  # noqa isort:skip

# Registering sources
module_datasource_map = app.config["DEFAULT_MODULE_DS_MAP"]
module_datasource_map.update(app.config["ADDITIONAL_MODULE_DS_MAP"])
ConnectorRegistry.register_sources(module_datasource_map)
# All of the fields located here should be considered legacy. The correct way
# to declare "global" dependencies is to define it in extensions.py,
# then initialize it in app.create_app(). These fields will be removed
# in subsequent PRs as things are migrated towards the factory pattern
app: Flask = current_app
cache = LocalProxy(lambda: cache_manager.cache)
conf = LocalProxy(lambda: current_app.config)
get_feature_flags = feature_flag_manager.get_feature_flags
get_css_manifest_files = manifest_processor.get_css_manifest_files
is_feature_enabled = feature_flag_manager.is_feature_enabled
results_backend = LocalProxy(lambda: results_backend_manager.results_backend)
results_backend_use_msgpack = LocalProxy(
    lambda: results_backend_manager.should_use_msgpack
)
tables_cache = LocalProxy(lambda: cache_manager.tables_cache)
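The tail of this hunk is the compatibility shim: the old module-level globals survive as LocalProxy objects that resolve against whichever app is current, so `from superset import conf` keeps working with a factory-built app. A self-contained sketch of that indirection (the demo app and the ROW_LIMIT key are hypothetical, not from the diff):

    from flask import Flask, current_app
    from werkzeug.local import LocalProxy

    # Resolved lazily, per active app context -- the same trick as the shim.
    conf = LocalProxy(lambda: current_app.config)

    app = Flask(__name__)
    app.config["ROW_LIMIT"] = 50000

    with app.app_context():
        print(conf["ROW_LIMIT"])  # 50000; the proxy defers to the active app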
@@ -0,0 +1,260 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import os

import wtforms_json
from flask import Flask, redirect
from flask_appbuilder import expose, IndexView
from flask_compress import Compress
from flask_wtf import CSRFProtect

from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import (
    _event_logger,
    APP_DIR,
    appbuilder,
    cache_manager,
    celery_app,
    db,
    feature_flag_manager,
    manifest_processor,
    migrate,
    results_backend_manager,
    talisman,
)
from superset.security import SupersetSecurityManager
from superset.utils.core import pessimistic_connection_handling
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value

logger = logging.getLogger(__name__)


def create_app():
    app = Flask(__name__)

    try:
        # Allow user to override our config completely
        config_module = os.environ.get("SUPERSET_CONFIG", "superset.config")
        app.config.from_object(config_module)

        app_initializer = app.config.get("APP_INITIALIZER", SupersetAppInitializer)(app)
        app_initializer.init_app()

        return app

    # Make sure that bootstrap errors ALWAYS get logged
    except Exception as ex:
        logger.exception("Failed to create app")
        raise ex


class SupersetIndexView(IndexView):
    @expose("/")
    def index(self):
        return redirect("/superset/welcome")


class SupersetAppInitializer:
    def __init__(self, app: Flask) -> None:
        super().__init__()

        self.flask_app = app
        self.config = app.config
        self.manifest: dict = {}

    def pre_init(self) -> None:
        """
        Called before all other init tasks
        """
        wtforms_json.init()

        if not os.path.exists(self.config["DATA_DIR"]):
            os.makedirs(self.config["DATA_DIR"])

    def post_init(self) -> None:
        """
        Called after all other init tasks are complete
        """
        pass

    def configure_celery(self) -> None:
        celery_app.config_from_object(self.config["CELERY_CONFIG"])
        celery_app.set_default()

    @staticmethod
    def init_views() -> None:
        # TODO - This should iterate over all views and register them with FAB...
        from superset import views  # noqa pylint: disable=unused-variable

    def init_app_in_ctx(self) -> None:
        """
        Runs init logic in the context of the app
        """
        self.configure_feature_flags()
        self.configure_fab()
        self.configure_data_sources()

        # Hook that provides administrators a handle on the Flask APP
        # after initialization
        flask_app_mutator = self.config["FLASK_APP_MUTATOR"]
        if flask_app_mutator:
            flask_app_mutator(self.flask_app)

        self.init_views()

    def init_app(self) -> None:
        """
        Main entry point which will delegate to other methods in
        order to fully init the app
        """
        self.pre_init()

        self.setup_db()

        self.configure_celery()

        self.setup_event_logger()

        self.setup_bundle_manifest()

        self.register_blueprints()

        self.configure_wtf()

        self.configure_logging()

        self.configure_middlewares()

        self.configure_cache()

        with self.flask_app.app_context():
            self.init_app_in_ctx()

        self.post_init()

    def setup_event_logger(self):
        _event_logger["event_logger"] = get_event_logger_from_cfg_value(
            self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
        )

    def configure_data_sources(self):
        # Registering sources
        module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
        module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
        ConnectorRegistry.register_sources(module_datasource_map)

    def configure_cache(self):
        cache_manager.init_app(self.flask_app)
        results_backend_manager.init_app(self.flask_app)

    def configure_feature_flags(self):
        feature_flag_manager.init_app(self.flask_app)

    def configure_fab(self):
        if self.config["SILENCE_FAB"]:
            logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)

        custom_sm = self.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager
        if not issubclass(custom_sm, SupersetSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
not FAB's security manager.
See [4565] in UPDATING.md"""
            )

        appbuilder.indexview = SupersetIndexView
        appbuilder.base_template = "superset/base.html"
        appbuilder.security_manager_class = custom_sm
        appbuilder.update_perms = False
        appbuilder.init_app(self.flask_app, db.session)

    def configure_middlewares(self):
        if self.config["ENABLE_CORS"]:
            from flask_cors import CORS

            CORS(self.flask_app, **self.config["CORS_OPTIONS"])

        if self.config["ENABLE_PROXY_FIX"]:
            from werkzeug.middleware.proxy_fix import ProxyFix

            self.flask_app.wsgi_app = ProxyFix(
                self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"]
            )

        if self.config["ENABLE_CHUNK_ENCODING"]:

            class ChunkedEncodingFix(object):  # pylint: disable=too-few-public-methods
                def __init__(self, app):
                    self.app = app

                def __call__(self, environ, start_response):
                    # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
                    # content-length and read the stream till the end.
                    if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked":
                        environ["wsgi.input_terminated"] = True
                    return self.app(environ, start_response)

            self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app)

        if self.config["UPLOAD_FOLDER"]:
            try:
                os.makedirs(self.config["UPLOAD_FOLDER"])
            except OSError:
                pass

        for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
            self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app)

        # Flask-Compress
        if self.config["ENABLE_FLASK_COMPRESS"]:
            Compress(self.flask_app)

        if self.config["TALISMAN_ENABLED"]:
            talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"])

    def configure_logging(self):
        self.config["LOGGING_CONFIGURATOR"].configure_logging(
            self.config, self.flask_app.debug
        )

    def setup_db(self):
        db.init_app(self.flask_app)

        with self.flask_app.app_context():
            pessimistic_connection_handling(db.engine)

        migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations")

    def configure_wtf(self):
        if self.config["WTF_CSRF_ENABLED"]:
            csrf = CSRFProtect(self.flask_app)
            csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"]
            for ex in csrf_exempt_list:
                csrf.exempt(ex)

    def register_blueprints(self):
        for bp in self.config["BLUEPRINTS"]:
            try:
                logger.info(f"Registering blueprint: '{bp.name}'")
                self.flask_app.register_blueprint(bp)
            except Exception:  # pylint: disable=broad-except
                logger.exception("blueprint registration failed")

    def setup_bundle_manifest(self):
        manifest_processor.init_app(self.flask_app)
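create_app() looks up an APP_INITIALIZER hook in config before falling back to SupersetAppInitializer, so deployments can customize startup without forking the factory. A hedged sketch of a superset_config.py override (the subclass and log line are hypothetical):

    from superset.app import SupersetAppInitializer

    class MyAppInitializer(SupersetAppInitializer):
        def post_init(self) -> None:
            # Runs last, after every other step in init_app()
            self.flask_app.logger.info("custom init complete")

    APP_INITIALIZER = MyAppInitializer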
@@ -15,17 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import click
from flask.cli import FlaskGroup

from superset.cli import create_app


@click.group(cls=FlaskGroup, create_app=create_app)
def cli():
    """This is a management script for the Superset application."""
    pass

from superset.cli import superset

if __name__ == '__main__':
    cli()
    superset()
@@ -25,27 +25,28 @@ import click
import yaml
from colorama import Fore, Style
from flask import g
from flask.cli import FlaskGroup, with_appcontext
from flask_appbuilder import Model
from pathlib2 import Path

from superset import app, appbuilder, db, examples, security_manager
from superset.common.tags import add_favorites, add_owners, add_types
from superset.utils import core as utils, dashboard_import_export, dict_import_export

config = app.config
celery_app = utils.get_celery_app(config)
from superset import app, appbuilder, security_manager
from superset.app import create_app
from superset.extensions import celery_app, db
from superset.utils import core as utils


def create_app(script_info=None):
    return app
@click.group(cls=FlaskGroup, create_app=create_app)
@with_appcontext
def superset():
    """This is a management script for the Superset application."""

@app.shell_context_processor
def make_shell_context():
    return dict(app=app, db=db)


@app.shell_context_processor
def make_shell_context():
    return dict(app=app, db=db)


@app.cli.command()
@superset.command()
@with_appcontext
def init():
    """Inits the Superset application"""
    utils.get_example_database()
@@ -53,7 +54,8 @@ def init():
    security_manager.sync_role_definitions()


@app.cli.command()
@superset.command()
@with_appcontext
@click.option("--verbose", "-v", is_flag=True, help="Show extra information")
def version(verbose):
    """Prints the current version number"""
@@ -62,7 +64,7 @@ def version(verbose):
        Fore.YELLOW
        + "Superset "
        + Fore.CYAN
        + "{version}".format(version=config["VERSION_STRING"])
        + "{version}".format(version=app.config["VERSION_STRING"])
    )
    print(Fore.BLUE + "-=" * 15)
    if verbose:
@@ -77,6 +79,8 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
    examples_db = utils.get_example_database()
    print(f"Loading examples metadata and related data into {examples_db}")

    from superset import examples

    examples.load_css_templates()

    print("Loading energy related dataset")
@@ -129,7 +133,8 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
    examples.load_tabbed_dashboard(only_metadata)


@app.cli.command()
@with_appcontext
@superset.command()
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
@click.option(
    "--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data"
@@ -142,7 +147,8 @@ def load_examples(load_test_data, only_metadata=False, force=False):
    load_examples_run(load_test_data, only_metadata, force)


@app.cli.command()
@with_appcontext
@superset.command()
@click.option("--database_name", "-d", help="Database name to change")
@click.option("--uri", "-u", help="Database URI to change")
def set_database_uri(database_name, uri):
@@ -150,7 +156,8 @@ def set_database_uri(database_name, uri):
    utils.get_or_create_db(database_name, uri)


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--datasource",
    "-d",
@@ -180,7 +187,8 @@ def refresh_druid(datasource, merge):
    session.commit()


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--path",
    "-p",
@@ -202,6 +210,8 @@ def refresh_druid(datasource, merge):
)
def import_dashboards(path, recursive, username):
    """Import dashboards from JSON"""
    from superset.utils import dashboard_import_export

    p = Path(path)
    files = []
    if p.is_file():
@@ -222,7 +232,8 @@ def import_dashboards(path, recursive, username):
            logging.error(e)


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--dashboard-file", "-f", default=None, help="Specify the the file to export to"
)
@@ -231,6 +242,8 @@ def import_dashboards(path, recursive, username):
)
def export_dashboards(print_stdout, dashboard_file):
    """Export dashboards to JSON"""
    from superset.utils import dashboard_import_export

    data = dashboard_import_export.export_dashboards(db.session)
    if print_stdout or not dashboard_file:
        print(data)
@@ -240,7 +253,8 @@ def export_dashboards(print_stdout, dashboard_file):
        data_stream.write(data)


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--path",
    "-p",
@@ -265,6 +279,8 @@ def export_dashboards(print_stdout, dashboard_file):
)
def import_datasources(path, sync, recursive):
    """Import datasources from YAML"""
    from superset.utils import dict_import_export

    sync_array = sync.split(",")
    p = Path(path)
    files = []
@@ -288,7 +304,8 @@ def import_datasources(path, sync, recursive):
            logging.error(e)


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--datasource-file", "-f", default=None, help="Specify the the file to export to"
)
@@ -313,6 +330,8 @@ def export_datasources(
    print_stdout, datasource_file, back_references, include_defaults
):
    """Export datasources to YAML"""
    from superset.utils import dict_import_export

    data = dict_import_export.export_to_dict(
        session=db.session,
        recursive=True,
@@ -327,7 +346,8 @@ def export_datasources(
        yaml.safe_dump(data, data_stream, default_flow_style=False)


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--back-references",
    "-b",
@@ -337,11 +357,14 @@ def export_datasources(
)
def export_datasource_schema(back_references):
    """Export datasource YAML schema to stdout"""
    from superset.utils import dict_import_export

    data = dict_import_export.export_schema_to_dict(back_references=back_references)
    yaml.safe_dump(data, stdout, default_flow_style=False)


@app.cli.command()
@superset.command()
@with_appcontext
def update_datasources_cache():
    """Refresh sqllab datasources cache"""
    from superset.models.core import Database
@@ -360,7 +383,8 @@ def update_datasources_cache():
            print("{}".format(str(e)))


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "--workers", "-w", type=int, help="Number of celery server workers to fire up"
)
@@ -372,14 +396,17 @@ def worker(workers):
    )
    if workers:
        celery_app.conf.update(CELERYD_CONCURRENCY=workers)
    elif config["SUPERSET_CELERY_WORKERS"]:
        celery_app.conf.update(CELERYD_CONCURRENCY=config["SUPERSET_CELERY_WORKERS"])
    elif app.config["SUPERSET_CELERY_WORKERS"]:
        celery_app.conf.update(
            CELERYD_CONCURRENCY=app.config["SUPERSET_CELERY_WORKERS"]
        )

    worker = celery_app.Worker(optimization="fair")
    worker.start()


@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
    "-p", "--port", default="5555", help="Port on which to start the Flower process"
)
@@ -409,7 +436,8 @@ def flower(port, address):
    Popen(cmd, shell=True).wait()


@app.cli.command()
@superset.command()
@with_appcontext
def load_test_users():
    """
    Loads admin, alpha, and gamma user for testing purposes
@@ -426,7 +454,7 @@ def load_test_users_run():

    Syncs permissions for those users/roles
    """
    if config["TESTING"]:
    if app.config["TESTING"]:

        sm = security_manager

@@ -463,11 +491,15 @@ def load_test_users_run():
        sm.get_session.commit()


@app.cli.command()
@superset.command()
@with_appcontext
def sync_tags():
    """Rebuilds special tags (owner, type, favorited by)."""
    # pylint: disable=no-member
    metadata = Model.metadata

    from superset.common.tags import add_favorites, add_owners, add_types

    add_types(db.engine, metadata)
    add_owners(db.engine, metadata)
    add_favorites(db.engine, metadata)
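Every command here moves from @app.cli.command() to @superset.command() plus @with_appcontext, so nothing needs a module-level app at import time. A minimal self-contained sketch of that pattern (toy app and command; Flask 1.x-era FlaskGroup passes a script_info argument to create_app, hence the default):

    import click
    from flask import Flask, current_app
    from flask.cli import FlaskGroup, with_appcontext

    def create_app(script_info=None):
        return Flask(__name__)

    @click.group(cls=FlaskGroup, create_app=create_app)
    def superset():
        """Management commands; each runs inside a fresh app context."""

    @superset.command()
    @with_appcontext
    def version():
        # current_app only works because of @with_appcontext
        click.echo(current_app.config.get("VERSION_STRING", "dev"))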
@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os

import celery
from flask_appbuilder import AppBuilder, SQLA
from flask_migrate import Migrate
from flask_talisman import Talisman
from werkzeug.local import LocalProxy

from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager


class ResultsBackendManager:
    def __init__(self) -> None:
        super().__init__()
        self._results_backend = None
        self._use_msgpack = False

    def init_app(self, app):
        self._results_backend = app.config.get("RESULTS_BACKEND")
        self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")

    @property
    def results_backend(self):
        return self._results_backend

    @property
    def should_use_msgpack(self):
        return self._use_msgpack


class UIManifestProcessor:
    def __init__(self, app_dir: str) -> None:
        super().__init__()
        self.app = None
        self.manifest: dict = {}
        self.manifest_file = f"{app_dir}/static/assets/dist/manifest.json"

    def init_app(self, app):
        self.app = app
        # Preload the cache
        self.parse_manifest_json()

        @app.context_processor
        def get_manifest():  # pylint: disable=unused-variable
            return dict(
                loaded_chunks=set(),
                get_unloaded_chunks=self.get_unloaded_chunks,
                js_manifest=self.get_js_manifest_files,
                css_manifest=self.get_css_manifest_files,
            )

    def parse_manifest_json(self):
        try:
            with open(self.manifest_file, "r") as f:
                # the manifest includes non-entry files
                # we only need entries in templates
                full_manifest = json.load(f)
                self.manifest = full_manifest.get("entrypoints", {})
        except Exception:  # pylint: disable=broad-except
            pass

    def get_js_manifest_files(self, filename):
        if self.app.debug:
            self.parse_manifest_json()
        entry_files = self.manifest.get(filename, {})
        return entry_files.get("js", [])

    def get_css_manifest_files(self, filename):
        if self.app.debug:
            self.parse_manifest_json()
        entry_files = self.manifest.get(filename, {})
        return entry_files.get("css", [])

    @staticmethod
    def get_unloaded_chunks(files, loaded_chunks):
        filtered_files = [f for f in files if f not in loaded_chunks]
        for f in filtered_files:
            loaded_chunks.add(f)
        return filtered_files


APP_DIR = os.path.dirname(__file__)

appbuilder = AppBuilder(update_perms=False)
cache_manager = CacheManager()
celery_app = celery.Celery()
db = SQLA()
_event_logger: dict = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
manifest_processor = UIManifestProcessor(APP_DIR)
migrate = Migrate()
results_backend_manager = ResultsBackendManager()
security_manager = LocalProxy(lambda: appbuilder.sm)
talisman = Talisman()
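All shared state now lives in this module as unbound singletons that are wired to a concrete app later via init_app(), mirroring the standard Flask extension protocol. A toy illustration of that contract (the manager class and config key here are hypothetical, not from the diff):

    from flask import Flask

    class ToyManager:
        """Module-level singleton; holds no app state until init_app()."""

        def __init__(self) -> None:
            self._value = None

        def init_app(self, app: Flask) -> None:
            self._value = app.config.get("TOY_VALUE")

        @property
        def value(self):
            return self._value

    toy_manager = ToyManager()  # created at import time, bound later

    app = Flask(__name__)
    app.config["TOY_VALUE"] = 42
    toy_manager.init_app(app)
    assert toy_manager.value == 42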
@@ -19,10 +19,6 @@
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from wtforms import Field

from superset import app

config = app.config


class CommaSeparatedListField(Field):
    widget = BS3TextFieldWidget()
@@ -42,9 +42,9 @@ from superset import (
)
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import json_iso_dttm_ser, QueryStatus, sources, zlib_compress
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
@@ -15,4 +15,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import cache, schedules
@@ -25,9 +25,9 @@ from celery.utils.log import get_task_logger
from sqlalchemy import and_, func

from superset import app, db
from superset.extensions import celery_app
from superset.models.core import Dashboard, Log, Slice
from superset.models.tags import Tag, TaggedObject
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import parse_human_datetime

logger = get_task_logger(__name__)
@@ -16,12 +16,20 @@
# under the License.
# pylint: disable=C,R,W

"""Utility functions used across Superset"""
"""
This is the main entrypoint used by Celery workers. As such,
it needs to call create_app() in order to initialize things properly
"""

# Superset framework imports
from superset import app
from superset.utils.core import get_celery_app
from superset import create_app
from superset.extensions import celery_app

# Globals
config = app.config
app = get_celery_app(config)
# Init the Flask app / configure everything
create_app()

# Need to import late, as the celery_app will have been setup by "create_app()"
from . import cache, schedules  # isort:skip

# Export the celery app globally for Celery (as run on the cmd line) to find
app = celery_app
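The worker entrypoint now calls create_app() purely for its side effects: configure_celery() in the factory pushes CELERY_CONFIG into the shared celery_app singleton. A hedged sketch of that wiring in isolation (the broker URL and stand-in config class are illustrative, not from the diff):

    import celery

    celery_app = celery.Celery()

    class CeleryConfig:  # stand-in for app.config["CELERY_CONFIG"]
        BROKER_URL = "redis://localhost:6379/0"
        CELERY_IMPORTS = ("superset.sql_lab",)

    celery_app.config_from_object(CeleryConfig)
    celery_app.set_default()  # make this instance Celery's current default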
@@ -39,13 +39,13 @@ from werkzeug.http import parse_cookie

# Superset framework imports
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.schedules import (
    EmailDeliveryType,
    get_scheduler_model,
    ScheduleType,
    SliceEmailReportFormat,
)
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import get_email_address_list, send_email_smtp

# Globals
@@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from flask import request
from typing import Optional

from superset import tables_cache
from flask import Flask, request
from flask_caching import Cache

from superset.extensions import cache_manager


def view_cache_key(*unused_args, **unused_kwargs) -> str:
@@ -43,7 +46,7 @@ def memoized_func(key=view_cache_key, attribute_in_key=None):
    """

    def wrap(f):
        if tables_cache:
        if cache_manager.tables_cache:

            def wrapped_f(self, *args, **kwargs):
                if not kwargs.get("cache", True):
@@ -55,11 +58,13 @@ def memoized_func(key=view_cache_key, attribute_in_key=None):
                    )
                else:
                    cache_key = key(*args, **kwargs)
                o = tables_cache.get(cache_key)
                o = cache_manager.tables_cache.get(cache_key)
                if not kwargs.get("force") and o is not None:
                    return o
                o = f(self, *args, **kwargs)
                tables_cache.set(cache_key, o, timeout=kwargs.get("cache_timeout"))
                cache_manager.tables_cache.set(
                    cache_key, o, timeout=kwargs.get("cache_timeout")
                )
                return o

        else:
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional

from flask import Flask
from flask_caching import Cache


class CacheManager:
    def __init__(self) -> None:
        super().__init__()

        self._tables_cache = None
        self._cache = None

    def init_app(self, app):
        self._cache = self._setup_cache(app, app.config.get("CACHE_CONFIG"))
        self._tables_cache = self._setup_cache(
            app, app.config.get("TABLE_NAMES_CACHE_CONFIG")
        )

    @staticmethod
    def _setup_cache(app: Flask, cache_config) -> Optional[Cache]:
        """Setup the flask-cache on a flask app"""
        if cache_config:
            if isinstance(cache_config, dict):
                if cache_config.get("CACHE_TYPE") != "null":
                    return Cache(app, config=cache_config)
            else:
                # Accepts a custom cache initialization function,
                # returning an object compatible with Flask-Caching API
                return cache_config(app)

        return None

    @property
    def tables_cache(self):
        return self._tables_cache

    @property
    def cache(self):
        return self._cache
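A hedged usage sketch of the new CacheManager (config keys as in the diff; the cache types chosen here are illustrative):

    from flask import Flask
    from superset.utils.cache_manager import CacheManager

    app = Flask(__name__)
    app.config["CACHE_CONFIG"] = {"CACHE_TYPE": "simple"}
    app.config["TABLE_NAMES_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}

    cache_manager = CacheManager()
    cache_manager.init_app(app)
    assert cache_manager.cache is not None  # "simple" cache was built
    assert cache_manager.tables_cache is None  # "null" type disables it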
@@ -791,20 +791,6 @@ def choicify(values):
    return [(v, v) for v in values]


def setup_cache(app: Flask, cache_config) -> Optional[Cache]:
    """Setup the flask-cache on a flask app"""
    if cache_config:
        if isinstance(cache_config, dict):
            if cache_config["CACHE_TYPE"] != "null":
                return Cache(app, config=cache_config)
        else:
            # Accepts a custom cache initialization function,
            # returning an object compatible with Flask-Caching API
            return cache_config(app)

    return None


def zlib_compress(data):
    """
    Compress things in a py2/3 safe fashion
@@ -832,19 +818,6 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
    return decompressed.decode("utf-8") if decode else decompressed


_celery_app = None


def get_celery_app(config):
    global _celery_app
    if _celery_app:
        return _celery_app
    _celery_app = celery.Celery()
    _celery_app.config_from_object(config["CELERY_CONFIG"])
    _celery_app.set_default()
    return _celery_app


def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
    result = {
        "clause": clause.upper(),
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy


class FeatureFlagManager:
    def __init__(self) -> None:
        super().__init__()
        self._get_feature_flags_func = None
        self._feature_flags = None

    def init_app(self, app):
        self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC")
        self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
        self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {})

    def get_feature_flags(self):
        if self._get_feature_flags_func:
            return self._get_feature_flags_func(deepcopy(self._feature_flags))

        return self._feature_flags

    def is_feature_enabled(self, feature):
        """Utility function for checking whether a feature is turned on"""
        return self.get_feature_flags().get(feature)
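A hedged usage sketch of FeatureFlagManager (config keys as in the diff; the flag name is hypothetical):

    from flask import Flask
    from superset.utils.feature_flag_manager import FeatureFlagManager

    app = Flask(__name__)
    app.config["DEFAULT_FEATURE_FLAGS"] = {"CLIENT_CACHE": False}
    app.config["FEATURE_FLAGS"] = {"CLIENT_CACHE": True}  # user override wins

    feature_flag_manager = FeatureFlagManager()
    feature_flag_manager.init_app(app)
    assert feature_flag_manager.is_feature_enabled("CLIENT_CACHE") is True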
@@ -686,7 +686,9 @@ class KV(BaseSupersetView):
    def get_value(self, key_id):
        kv = None
        try:
            kv = db.session.query(models.KeyValue).filter_by(id=key_id).one()
            kv = db.session.query(models.KeyValue).filter_by(id=key_id).scalar()
            if not kv:
                return Response(status=404, content_type="text/plain")
        except Exception as e:
            return json_error_response(e)
        return Response(kv.value, status=200, content_type="text/plain")
@@ -736,6 +738,8 @@ appbuilder.add_view_no_menu(R)
class Superset(BaseSupersetView):
    """The base views for Superset!"""

    logger = logging.getLogger(__name__)

    @has_access_api
    @expose("/datasources/")
    def datasources(self):
@@ -2059,6 +2063,7 @@ class Superset(BaseSupersetView):
            )
            obj.get_json()
        except Exception as e:
            self.logger.exception("Failed to warm up cache")
            return json_error_response(utils.error_msg_from_exception(e))
        return json_success(
            json.dumps(
@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from unittest import mock

from superset import app, db, security_manager
from tests.test_app import app  # isort:skip
from superset import db, security_manager
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.druid.models import DruidDatasource
from superset.connectors.sqla.models import SqlaTable
@@ -94,22 +96,24 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
class RequestAccessTests(SupersetTestCase):
    @classmethod
    def setUpClass(cls):
        security_manager.add_role("override_me")
        security_manager.add_role(TEST_ROLE_1)
        security_manager.add_role(TEST_ROLE_2)
        security_manager.add_role(DB_ACCESS_ROLE)
        security_manager.add_role(SCHEMA_ACCESS_ROLE)
        db.session.commit()
        with app.app_context():
            security_manager.add_role("override_me")
            security_manager.add_role(TEST_ROLE_1)
            security_manager.add_role(TEST_ROLE_2)
            security_manager.add_role(DB_ACCESS_ROLE)
            security_manager.add_role(SCHEMA_ACCESS_ROLE)
            db.session.commit()

    @classmethod
    def tearDownClass(cls):
        override_me = security_manager.find_role("override_me")
        db.session.delete(override_me)
        db.session.delete(security_manager.find_role(TEST_ROLE_1))
        db.session.delete(security_manager.find_role(TEST_ROLE_2))
        db.session.delete(security_manager.find_role(DB_ACCESS_ROLE))
        db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
        db.session.commit()
        with app.app_context():
            override_me = security_manager.find_role("override_me")
            db.session.delete(override_me)
            db.session.delete(security_manager.find_role(TEST_ROLE_1))
            db.session.delete(security_manager.find_role(TEST_ROLE_2))
            db.session.delete(security_manager.find_role(DB_ACCESS_ROLE))
            db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
            db.session.commit()

    def setUp(self):
        self.login("admin")
@@ -14,52 +14,57 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import imp
import json
import unittest
from unittest.mock import Mock, patch
from unittest.mock import Mock

import pandas as pd
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase

from superset import app, db, is_feature_enabled, security_manager
from tests.test_app import app  # isort:skip
from superset import db, security_manager
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.core import Database
from superset.utils.core import get_example_database

BASE_DIR = app.config["BASE_DIR"]
FAKE_DB_NAME = "fake_db_100"


class SupersetTestCase(unittest.TestCase):
class SupersetTestCase(TestCase):
    def __init__(self, *args, **kwargs):
        super(SupersetTestCase, self).__init__(*args, **kwargs)
        self.client = app.test_client()
        self.maxDiff = None

    def create_app(self):
        return app

    @classmethod
    def create_druid_test_objects(cls):
        # create druid cluster and druid datasources
        session = db.session
        cluster = (
            session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
        )
        if not cluster:
            cluster = DruidCluster(cluster_name="druid_test")
            session.add(cluster)
            session.commit()
        with app.app_context():
            session = db.session
            cluster = (
                session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
            )
            if not cluster:
                cluster = DruidCluster(cluster_name="druid_test")
                session.add(cluster)
                session.commit()

        druid_datasource1 = DruidDatasource(
            datasource_name="druid_ds_1", cluster_name="druid_test"
        )
        session.add(druid_datasource1)
        druid_datasource2 = DruidDatasource(
            datasource_name="druid_ds_2", cluster_name="druid_test"
        )
        session.add(druid_datasource2)
        session.commit()
            druid_datasource1 = DruidDatasource(
                datasource_name="druid_ds_1", cluster_name="druid_test"
            )
            session.add(druid_datasource1)
            druid_datasource2 = DruidDatasource(
                datasource_name="druid_ds_2", cluster_name="druid_test"
            )
            session.add(druid_datasource2)
            session.commit()

    def get_table(self, table_id):
        return db.session.query(SqlaTable).filter_by(id=table_id).one()
@@ -210,7 +215,7 @@ class SupersetTestCase(unittest.TestCase):

    def create_fake_db(self):
        self.login(username="admin")
        database_name = "fake_db_100"
        database_name = FAKE_DB_NAME
        db_id = 100
        extra = """{
            "schemas_allowed_for_csv_upload":
@@ -225,6 +230,15 @@ class SupersetTestCase(unittest.TestCase):
            extra=extra,
        )

    def delete_fake_db(self):
        database = (
            db.session.query(Database)
            .filter(Database.database_name == FAKE_DB_NAME)
            .scalar()
        )
        if database:
            db.session.delete(database)

    def validate_sql(
        self,
        sql,
@@ -246,18 +260,6 @@ class SupersetTestCase(unittest.TestCase):
            raise Exception("validate_sql failed")
        return resp

    @patch.dict("superset._feature_flags", {"FOO": True}, clear=True)
    def test_existing_feature_flags(self):
        self.assertTrue(is_feature_enabled("FOO"))

    @patch.dict("superset._feature_flags", {}, clear=True)
    def test_nonexistent_feature_flags(self):
        self.assertFalse(is_feature_enabled("FOO"))

    def test_feature_flags(self):
        self.assertEqual(is_feature_enabled("foo"), "bar")
        self.assertEqual(is_feature_enabled("super"), "set")

    def get_dash_by_slug(self, dash_slug):
        sesh = db.session()
        return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()
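The test base class switches from unittest.TestCase to flask_testing.TestCase, whose create_app() hook hands every test a known app plus a matching test client. A minimal sketch of that pattern with a toy app (not Superset's actual test configuration):

    from flask import Flask
    from flask_testing import TestCase

    class MyTest(TestCase):
        def create_app(self) -> Flask:
            # flask_testing calls this once per test to build self.app
            app = Flask(__name__)
            app.config["TESTING"] = True
            return app

        def test_app_is_wired(self):
            self.assertTrue(self.app.config["TESTING"])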
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset Celery worker"""
import datetime
import json
@@ -22,7 +23,8 @@ import time
import unittest
import unittest.mock as mock

from superset import app, db, sql_lab
from tests.test_app import app  # isort:skip
from superset import db, sql_lab
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.helpers import QueryStatus
@@ -32,20 +34,9 @@ from superset.utils.core import get_example_database

from .base_tests import SupersetTestCase

BASE_DIR = app.config["BASE_DIR"]
CELERY_SLEEP_TIME = 5


class CeleryConfig(object):
    BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL
    CELERY_IMPORTS = ("superset.sql_lab",)
    CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
    CONCURRENCY = 1


app.config["CELERY_CONFIG"] = CeleryConfig


class UtilityFunctionTests(SupersetTestCase):

    # TODO(bkyryliuk): support more cases in CTA function.
@@ -79,10 +70,6 @@ class UtilityFunctionTests(SupersetTestCase):


class CeleryTestCase(SupersetTestCase):
    def __init__(self, *args, **kwargs):
        super(CeleryTestCase, self).__init__(*args, **kwargs)
        self.client = app.test_client()

    def get_query_by_name(self, sql):
        session = db.session
        query = session.query(Query).filter_by(sql=sql).first()
@@ -97,11 +84,22 @@ class CeleryTestCase(SupersetTestCase):

    @classmethod
    def setUpClass(cls):
        db.session.query(Query).delete()
        db.session.commit()
        with app.app_context():

        worker_command = BASE_DIR + "/bin/superset worker -w 2"
        subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE)
            class CeleryConfig(object):
                BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL
                CELERY_IMPORTS = ("superset.sql_lab",)
                CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
                CONCURRENCY = 1

            app.config["CELERY_CONFIG"] = CeleryConfig

            db.session.query(Query).delete()
            db.session.commit()

            base_dir = app.config["BASE_DIR"]
            worker_command = base_dir + "/bin/superset worker -w 2"
            subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE)

    @classmethod
    def tearDownClass(cls):
@@ -190,6 +188,7 @@ class CeleryTestCase(SupersetTestCase):
        result = self.run_sql(
            db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
        )
        db.session.close()
        assert result["query"]["state"] in (
            QueryStatus.PENDING,
            QueryStatus.RUNNING,
@@ -224,6 +223,7 @@ class CeleryTestCase(SupersetTestCase):
        result = self.run_sql(
            db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True
        )
        db.session.close()
        assert result["query"]["state"] in (
            QueryStatus.PENDING,
            QueryStatus.RUNNING,
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# isort:skip_file
 """Unit tests for Superset"""
 import cgi
 import csv

@@ -33,7 +34,8 @@ import pandas as pd
 import psycopg2
 import sqlalchemy as sqla

-from superset import app, dataframe, db, jinja_context, security_manager, sql_lab
+from tests.test_app import app
+from superset import dataframe, db, jinja_context, security_manager, sql_lab
 from superset.connectors.sqla.models import SqlaTable
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec

@@ -51,16 +53,13 @@ class CoreTests(SupersetTestCase):
     def __init__(self, *args, **kwargs):
         super(CoreTests, self).__init__(*args, **kwargs)

-    @classmethod
-    def setUpClass(cls):
-        cls.table_ids = {
-            tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all())
-        }
-
     def setUp(self):
         db.session.query(Query).delete()
         db.session.query(models.DatasourceAccessRequest).delete()
         db.session.query(models.Log).delete()
+        self.table_ids = {
+            tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all())
+        }

     def tearDown(self):
         db.session.query(Query).delete()

@@ -196,12 +195,11 @@ class CoreTests(SupersetTestCase):

     def test_save_slice(self):
         self.login(username="admin")
-        slice_name = "Energy Sankey"
+        slice_name = f"Energy Sankey"
         slice_id = self.get_slice(slice_name, db.session).id
-        db.session.commit()
-        copy_name = "Test Sankey Save"
+        copy_name = f"Test Sankey Save_{random.random()}"
         tbl_id = self.table_ids.get("energy_usage")
-        new_slice_name = "Test Sankey Overwirte"
+        new_slice_name = f"Test Sankey Overwrite_{random.random()}"

         url = (
             "/superset/explore/table/{}/?slice_name={}&"

@@ -216,13 +214,17 @@ class CoreTests(SupersetTestCase):
             "slice_id": slice_id,
         }
         # Changing name and save as a new slice
-        self.get_resp(
+        resp = self.client.post(
             url.format(tbl_id, copy_name, "saveas"),
-            {"form_data": json.dumps(form_data)},
+            data={"form_data": json.dumps(form_data)},
         )
-        slices = db.session.query(models.Slice).filter_by(slice_name=copy_name).all()
-        assert len(slices) == 1
-        new_slice_id = slices[0].id
+        db.session.expunge_all()
+        new_slice_id = resp.json["form_data"]["slice_id"]
         slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one()

         self.assertEqual(slc.slice_name, copy_name)
         form_data.pop("slice_id")  # We don't save the slice id when saving as
         self.assertEqual(slc.viz.form_data, form_data)

         form_data = {
             "viz_type": "sankey",

@@ -233,14 +235,18 @@ class CoreTests(SupersetTestCase):
             "time_range": "now",
         }
         # Setting the name back to its original name by overwriting new slice
-        self.get_resp(
+        self.client.post(
             url.format(tbl_id, new_slice_name, "overwrite"),
-            {"form_data": json.dumps(form_data)},
+            data={"form_data": json.dumps(form_data)},
         )
-        slc = db.session.query(models.Slice).filter_by(id=new_slice_id).first()
-        assert slc.slice_name == new_slice_name
-        assert slc.viz.form_data == form_data
+        db.session.expunge_all()
+        slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one()
+        self.assertEqual(slc.slice_name, new_slice_name)
+        self.assertEqual(slc.viz.form_data, form_data)

         # Cleanup
         db.session.delete(slc)
         db.session.commit()

     def test_filter_endpoint(self):
         self.login(username="admin")

@@ -406,10 +412,16 @@ class CoreTests(SupersetTestCase):
         database = utils.get_example_database()
         self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)

+        # Need to clean up after ourselves
+        database.impersonate_user = False
+        database.allow_dml = False
+        database.allow_run_async = False
+        db.session.commit()
+
     def test_warm_up_cache(self):
         slc = self.get_slice("Girls", db.session)
         data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
-        assert data == [{"slice_id": slc.id, "slice_name": slc.slice_name}]
+        self.assertEqual(data, [{"slice_id": slc.id, "slice_name": slc.slice_name}])

         data = self.get_json_resp(
             "/superset/warm_up_cache?table_name=energy_usage&db_name=main"

@@ -430,13 +442,10 @@ class CoreTests(SupersetTestCase):
         assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8"))

     def test_kv(self):
         self.logout()
         self.login(username="admin")

-        try:
-            resp = self.client.post("/kv/store/", data=dict())
-        except Exception:
-            self.assertRaises(TypeError)
+        resp = self.client.get("/kv/10001/")
+        self.assertEqual(404, resp.status_code)

         value = json.dumps({"data": "this is a test"})
         resp = self.client.post("/kv/store/", data=dict(data=value))

@@ -449,11 +458,6 @@ class CoreTests(SupersetTestCase):
         self.assertEqual(resp.status_code, 200)
         self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8")))

-        try:
-            resp = self.client.get("/kv/10001/")
-        except Exception:
-            self.assertRaises(TypeError)
-
     def test_gamma(self):
         self.login(username="gamma")
         assert "Charts" in self.get_resp("/chart/list/")

@@ -808,6 +812,7 @@ class CoreTests(SupersetTestCase):
             )
         )
         assert data == ["this_schema_is_allowed_too"]
+        self.delete_fake_db()

     def test_select_star(self):
         self.login(username="admin")

@@ -950,7 +955,11 @@ class CoreTests(SupersetTestCase):
         self.assertDictEqual(deserialized_payload, payload)
         expand_data.assert_called_once()

-    @mock.patch.dict("superset._feature_flags", {"FOO": lambda x: 1}, clear=True)
+    @mock.patch.dict(
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"FOO": lambda x: 1},
+        clear=True,
+    )
     def test_feature_flag_serialization(self):
         """
         Functions in feature flags don't break bootstrap data serialization.

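Several hunks above swap the `self.get_resp(...)` helper for Flask's test client so the test can read the JSON body of the save-slice response directly via `resp.json`. A hedged sketch of that client pattern with a toy app and route, not Superset's actual endpoint:

import json

from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route("/save", methods=["POST"])
def save():
    # echo the posted form field back as JSON, like a simplified explore endpoint
    return jsonify(form_data=json.loads(request.form["form_data"]))

client = app.test_client()
resp = client.post("/save", data={"form_data": json.dumps({"slice_id": 1})})
assert resp.json["form_data"]["slice_id"] == 1  # resp.json parses the JSON body
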
@@ -17,6 +17,7 @@
 """Unit tests for Superset"""
 import json
 import unittest
+from random import random

 from flask import escape
 from sqlalchemy import func

@@ -399,16 +400,19 @@ class DashboardTests(SupersetTestCase):

         self.grant_public_access_to_table(table)

+        hidden_dash_slug = f"hidden_dash_{random()}"
+        published_dash_slug = f"published_dash_{random()}"
+
         # Create a published and hidden dashboard and add them to the database
         published_dash = models.Dashboard()
         published_dash.dashboard_title = "Published Dashboard"
-        published_dash.slug = "published_dash"
+        published_dash.slug = published_dash_slug
         published_dash.slices = [slice]
         published_dash.published = True

         hidden_dash = models.Dashboard()
         hidden_dash.dashboard_title = "Hidden Dashboard"
-        hidden_dash.slug = "hidden_dash"
+        hidden_dash.slug = hidden_dash_slug
         hidden_dash.slices = [slice]
         hidden_dash.published = False

@@ -417,22 +421,24 @@ class DashboardTests(SupersetTestCase):
         db.session.commit()

         resp = self.get_resp("/dashboard/list/")
-        self.assertNotIn("/superset/dashboard/hidden_dash/", resp)
-        self.assertIn("/superset/dashboard/published_dash/", resp)
+        self.assertNotIn(f"/superset/dashboard/{hidden_dash_slug}/", resp)
+        self.assertIn(f"/superset/dashboard/{published_dash_slug}/", resp)

     def test_users_can_view_own_dashboard(self):
         user = security_manager.find_user("gamma")
+        my_dash_slug = f"my_dash_{random()}"
+        not_my_dash_slug = f"not_my_dash_{random()}"

         # Create one dashboard I own and another that I don't
         dash = models.Dashboard()
         dash.dashboard_title = "My Dashboard"
-        dash.slug = "my_dash"
+        dash.slug = my_dash_slug
         dash.owners = [user]
         dash.slices = []

         hidden_dash = models.Dashboard()
         hidden_dash.dashboard_title = "Not My Dashboard"
-        hidden_dash.slug = "not_my_dash"
+        hidden_dash.slug = not_my_dash_slug
         hidden_dash.slices = []
         hidden_dash.owners = []

@@ -443,29 +449,27 @@ class DashboardTests(SupersetTestCase):
         self.login(user.username)

         resp = self.get_resp("/dashboard/list/")
-        self.assertIn("/superset/dashboard/my_dash/", resp)
-        self.assertNotIn("/superset/dashboard/not_my_dash/", resp)
+        self.assertIn(f"/superset/dashboard/{my_dash_slug}/", resp)
+        self.assertNotIn(f"/superset/dashboard/{not_my_dash_slug}/", resp)

     def test_users_can_view_favorited_dashboards(self):
         user = security_manager.find_user("gamma")
+        fav_dash_slug = f"my_favorite_dash_{random()}"
+        regular_dash_slug = f"regular_dash_{random()}"

         favorite_dash = models.Dashboard()
         favorite_dash.dashboard_title = "My Favorite Dashboard"
-        favorite_dash.slug = "my_favorite_dash"
+        favorite_dash.slug = fav_dash_slug

         regular_dash = models.Dashboard()
         regular_dash.dashboard_title = "A Plain Ol Dashboard"
-        regular_dash.slug = "regular_dash"
+        regular_dash.slug = regular_dash_slug

         db.session.merge(favorite_dash)
         db.session.merge(regular_dash)
         db.session.commit()

-        dash = (
-            db.session.query(models.Dashboard)
-            .filter_by(slug="my_favorite_dash")
-            .first()
-        )
+        dash = db.session.query(models.Dashboard).filter_by(slug=fav_dash_slug).first()

         favorites = models.FavStar()
         favorites.obj_id = dash.id

@@ -478,12 +482,12 @@ class DashboardTests(SupersetTestCase):
         self.login(user.username)

         resp = self.get_resp("/dashboard/list/")
-        self.assertIn("/superset/dashboard/my_favorite_dash/", resp)
+        self.assertIn(f"/superset/dashboard/{fav_dash_slug}/", resp)

     def test_user_can_not_view_unpublished_dash(self):
         admin_user = security_manager.find_user("admin")
         gamma_user = security_manager.find_user("gamma")
-        slug = "admin_owned_unpublished_dash"
+        slug = f"admin_owned_unpublished_dash_{random()}"

         # Create a dashboard owned by admin and unpublished
         dash = models.Dashboard()

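The recurring change in this file is mechanical: fixed slugs become `random()`-suffixed ones so repeated runs against the same database don't collide on the dashboard's unique slug column. A tiny sketch of the idea; the uuid4 line is an alternative, not something the commit uses:

from random import random
from uuid import uuid4

slug = f"hidden_dash_{random()}"    # e.g. "hidden_dash_0.8444218515250481"
other = f"hidden_dash_{random()}"
assert slug != other                     # collisions possible but vanishingly rare
stronger = f"hidden_dash_{uuid4().hex}"  # stronger uniqueness, if ever needed
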
@@ -60,7 +60,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.verify_presto_column(presto_column, expected_results)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_get_simple_row_column(self):
         presto_column = ("column_name", "row(nested_obj double)", "")

@@ -68,7 +70,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.verify_presto_column(presto_column, expected_results)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
         presto_column = ("column name", "row(nested_obj double)", "")

@@ -76,7 +80,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.verify_presto_column(presto_column, expected_results)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
         presto_column = ("column_name", 'row("Field Name(Tricky, Name)" double)', "")

@@ -87,7 +93,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.verify_presto_column(presto_column, expected_results)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_get_simple_array_column(self):
         presto_column = ("column_name", "array(double)", "")

@@ -95,7 +103,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.verify_presto_column(presto_column, expected_results)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_get_row_within_array_within_row_column(self):
         presto_column = (

@@ -112,7 +122,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.verify_presto_column(presto_column, expected_results)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_get_array_within_row_within_array_column(self):
         presto_column = (

@@ -147,7 +159,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.assertEqual(actual_result.name, expected_result["label"])

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_expand_data_with_simple_structural_columns(self):
         cols = [

@@ -182,7 +196,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.assertEqual(actual_expanded_cols, expected_expanded_cols)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
     )
     def test_presto_expand_data_with_complex_row_columns(self):
         cols = [

@@ -229,7 +245,9 @@ class PrestoTests(DbEngineSpecTestCase):
         self.assertEqual(actual_expanded_cols, expected_expanded_cols)

     @mock.patch.dict(
-        "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"PRESTO_EXPAND_DATA": True},
+        clear=True,
    )
     def test_presto_expand_data_with_complex_array_columns(self):
         cols = [

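Every hunk in this file makes the same substitution: the feature-flag dict no longer lives at module level as `superset._feature_flags` but on the `feature_flag_manager` extension, so the `mock.patch.dict` target moves with it. A short sketch of what `patch.dict` does with that target; the target string is real per the diff, the surrounding code is illustrative:

from unittest import mock

with mock.patch.dict(
    "superset.extensions.feature_flag_manager._feature_flags",
    {"PRESTO_EXPAND_DATA": True},
    clear=True,  # drop all other flags for the duration of the block
):
    pass  # code under test sees exactly one flag; the dict is restored on exit
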
@@ -14,12 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# isort:skip_file
 """Unit tests for Superset"""
 import json
 import unittest

 import yaml

+from tests.test_app import app
 from superset import db
 from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

@@ -41,15 +43,16 @@ class DictImportExportTests(SupersetTestCase):

     @classmethod
     def delete_imports(cls):
-        # Imported data clean up
-        session = db.session
-        for table in session.query(SqlaTable):
-            if DBREF in table.params_dict:
-                session.delete(table)
-        for datasource in session.query(DruidDatasource):
-            if DBREF in datasource.params_dict:
-                session.delete(datasource)
-        session.commit()
+        with app.app_context():
+            # Imported data clean up
+            session = db.session
+            for table in session.query(SqlaTable):
+                if DBREF in table.params_dict:
+                    session.delete(table)
+            for datasource in session.query(DruidDatasource):
+                if DBREF in datasource.params_dict:
+                    session.delete(datasource)
+            session.commit()

     @classmethod
     def setUpClass(cls):

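The `with app.app_context():` wrapper appears here (and again in several files below) because, under the new app factory, no application is bound at import time; class-level helpers that touch `db.session` must push a context themselves. A minimal, generic Flask-SQLAlchemy sketch of the failure mode, not Superset's actual wiring:

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import text

db = SQLAlchemy()

def create_app():
    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
    db.init_app(app)
    return app

app = create_app()

# Using db.session here, outside a context, raises RuntimeError
# ("Working outside of application context"). Inside the block,
# the session is bound to this app:
with app.app_context():
    db.session.execute(text("SELECT 1"))
    db.session.commit()
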
@@ -47,7 +47,7 @@ def emplace(metrics_dict, metric_name, is_postagg=False):


 # Unit tests that can be run without initializing base tests
-class DruidFuncTestCase(unittest.TestCase):
+class DruidFuncTestCase(SupersetTestCase):
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
     )

@@ -26,13 +26,14 @@ from unittest import mock

 from superset import app
 from superset.utils import core as utils
+from tests.base_tests import SupersetTestCase

 from .utils import read_fixture

 send_email_test = mock.Mock()


-class EmailSmtpTest(unittest.TestCase):
+class EmailSmtpTest(SupersetTestCase):
     def setUp(self):
         app.config["smtp_ssl"] = False

@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest.mock import patch
+
+from superset import is_feature_enabled
+from tests.base_tests import SupersetTestCase
+
+
+class FeatureFlagTests(SupersetTestCase):
+    @patch.dict(
+        "superset.extensions.feature_flag_manager._feature_flags",
+        {"FOO": True},
+        clear=True,
+    )
+    def test_existing_feature_flags(self):
+        self.assertTrue(is_feature_enabled("FOO"))
+
+    @patch.dict(
+        "superset.extensions.feature_flag_manager._feature_flags", {}, clear=True
+    )
+    def test_nonexistent_feature_flags(self):
+        self.assertFalse(is_feature_enabled("FOO"))
+
+    def test_feature_flags(self):
+        self.assertEqual(is_feature_enabled("foo"), "bar")
+        self.assertEqual(is_feature_enabled("super"), "set")

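These new tests pin down the contract implied elsewhere in the diff: `is_feature_enabled` reads from a dict owned by the `feature_flag_manager` extension. A hedged sketch of the shape such a manager might have; everything beyond the `_feature_flags` attribute name is an assumption, not Superset's code:

class FeatureFlagManager:
    def __init__(self):
        self._feature_flags = {}

    def init_app(self, app):
        # merge flag config from the app; the config key name is an assumption
        self._feature_flags = dict(app.config.get("FEATURE_FLAGS", {}))

    def is_feature_enabled(self, feature):
        # the last test above expects non-boolean flag values (e.g. "bar")
        # to pass through, so no bool() coercion here
        return self._feature_flags.get(feature)
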
@@ -14,13 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# isort:skip_file
 """Unit tests for Superset"""
 import json
 import unittest

-from flask import Flask, g
+from flask import g
 from sqlalchemy.orm.session import make_transient

+from tests.test_app import app
 from superset import db, security_manager
 from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

@@ -35,21 +37,22 @@ class ImportExportTests(SupersetTestCase):

     @classmethod
     def delete_imports(cls):
-        # Imported data clean up
-        session = db.session
-        for slc in session.query(models.Slice):
-            if "remote_id" in slc.params_dict:
-                session.delete(slc)
-        for dash in session.query(models.Dashboard):
-            if "remote_id" in dash.params_dict:
-                session.delete(dash)
-        for table in session.query(SqlaTable):
-            if "remote_id" in table.params_dict:
-                session.delete(table)
-        for datasource in session.query(DruidDatasource):
-            if "remote_id" in datasource.params_dict:
-                session.delete(datasource)
-        session.commit()
+        with app.app_context():
+            # Imported data clean up
+            session = db.session
+            for slc in session.query(models.Slice):
+                if "remote_id" in slc.params_dict:
+                    session.delete(slc)
+            for dash in session.query(models.Dashboard):
+                if "remote_id" in dash.params_dict:
+                    session.delete(dash)
+            for table in session.query(SqlaTable):
+                if "remote_id" in table.params_dict:
+                    session.delete(table)
+            for datasource in session.query(DruidDatasource):
+                if "remote_id" in datasource.params_dict:
+                    session.delete(datasource)
+            session.commit()

     @classmethod
     def setUpClass(cls):

@@ -460,68 +463,64 @@ class ImportExportTests(SupersetTestCase):
         )

     def test_import_new_dashboard_slice_reset_ownership(self):
-        app = Flask("test_import_dashboard_slice_set_user")
-        with app.app_context():
-            admin_user = security_manager.find_user(username="admin")
-            self.assertTrue(admin_user)
-            gamma_user = security_manager.find_user(username="gamma")
-            self.assertTrue(gamma_user)
-            g.user = gamma_user
+        admin_user = security_manager.find_user(username="admin")
+        self.assertTrue(admin_user)
+        gamma_user = security_manager.find_user(username="gamma")
+        self.assertTrue(gamma_user)
+        g.user = gamma_user

-            dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
-            # set another user as an owner of importing dashboard
-            dash_with_1_slice.created_by = admin_user
-            dash_with_1_slice.changed_by = admin_user
-            dash_with_1_slice.owners = [admin_user]
+        dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
+        # set another user as an owner of importing dashboard
+        dash_with_1_slice.created_by = admin_user
+        dash_with_1_slice.changed_by = admin_user
+        dash_with_1_slice.owners = [admin_user]

-            imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
-            imported_dash = self.get_dash(imported_dash_id)
-            self.assertEqual(imported_dash.created_by, gamma_user)
-            self.assertEqual(imported_dash.changed_by, gamma_user)
-            self.assertEqual(imported_dash.owners, [gamma_user])
+        imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
+        imported_dash = self.get_dash(imported_dash_id)
+        self.assertEqual(imported_dash.created_by, gamma_user)
+        self.assertEqual(imported_dash.changed_by, gamma_user)
+        self.assertEqual(imported_dash.owners, [gamma_user])

-            imported_slc = imported_dash.slices[0]
-            self.assertEqual(imported_slc.created_by, gamma_user)
-            self.assertEqual(imported_slc.changed_by, gamma_user)
-            self.assertEqual(imported_slc.owners, [gamma_user])
+        imported_slc = imported_dash.slices[0]
+        self.assertEqual(imported_slc.created_by, gamma_user)
+        self.assertEqual(imported_slc.changed_by, gamma_user)
+        self.assertEqual(imported_slc.owners, [gamma_user])

     def test_import_override_dashboard_slice_reset_ownership(self):
-        app = Flask("test_import_dashboard_slice_set_user")
-        with app.app_context():
-            admin_user = security_manager.find_user(username="admin")
-            self.assertTrue(admin_user)
-            gamma_user = security_manager.find_user(username="gamma")
-            self.assertTrue(gamma_user)
-            g.user = gamma_user
+        admin_user = security_manager.find_user(username="admin")
+        self.assertTrue(admin_user)
+        gamma_user = security_manager.find_user(username="gamma")
+        self.assertTrue(gamma_user)
+        g.user = gamma_user

-            dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
+        dash_with_1_slice = self._create_dashboard_for_import(id_=10300)

-            imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
-            imported_dash = self.get_dash(imported_dash_id)
-            self.assertEqual(imported_dash.created_by, gamma_user)
-            self.assertEqual(imported_dash.changed_by, gamma_user)
-            self.assertEqual(imported_dash.owners, [gamma_user])
+        imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
+        imported_dash = self.get_dash(imported_dash_id)
+        self.assertEqual(imported_dash.created_by, gamma_user)
+        self.assertEqual(imported_dash.changed_by, gamma_user)
+        self.assertEqual(imported_dash.owners, [gamma_user])

-            imported_slc = imported_dash.slices[0]
-            self.assertEqual(imported_slc.created_by, gamma_user)
-            self.assertEqual(imported_slc.changed_by, gamma_user)
-            self.assertEqual(imported_slc.owners, [gamma_user])
+        imported_slc = imported_dash.slices[0]
+        self.assertEqual(imported_slc.created_by, gamma_user)
+        self.assertEqual(imported_slc.changed_by, gamma_user)
+        self.assertEqual(imported_slc.owners, [gamma_user])

-            # re-import with another user shouldn't change the permissions
-            g.user = admin_user
+        # re-import with another user shouldn't change the permissions
+        g.user = admin_user

-            dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
+        dash_with_1_slice = self._create_dashboard_for_import(id_=10300)

-            imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
-            imported_dash = self.get_dash(imported_dash_id)
-            self.assertEqual(imported_dash.created_by, gamma_user)
-            self.assertEqual(imported_dash.changed_by, gamma_user)
-            self.assertEqual(imported_dash.owners, [gamma_user])
+        imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
+        imported_dash = self.get_dash(imported_dash_id)
+        self.assertEqual(imported_dash.created_by, gamma_user)
+        self.assertEqual(imported_dash.changed_by, gamma_user)
+        self.assertEqual(imported_dash.owners, [gamma_user])

-            imported_slc = imported_dash.slices[0]
-            self.assertEqual(imported_slc.created_by, gamma_user)
-            self.assertEqual(imported_slc.changed_by, gamma_user)
-            self.assertEqual(imported_slc.owners, [gamma_user])
+        imported_slc = imported_dash.slices[0]
+        self.assertEqual(imported_slc.created_by, gamma_user)
+        self.assertEqual(imported_slc.changed_by, gamma_user)
+        self.assertEqual(imported_slc.owners, [gamma_user])

     def _create_dashboard_for_import(self, id_=10100):
         slc = self.create_slice("health_slc" + str(id_), id=id_ + 1)

@@ -14,27 +14,36 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from superset import examples
-from superset.cli import load_test_users_run
-
 from .base_tests import SupersetTestCase


 class SupersetDataFrameTestCase(SupersetTestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.examples = None
+
+    def setUp(self) -> None:
+        # Late importing here as we need an app context to be pushed...
+        from superset import examples
+
+        self.examples = examples
+
     def test_load_css_templates(self):
-        examples.load_css_templates()
+        self.examples.load_css_templates()

     def test_load_energy(self):
-        examples.load_energy()
+        self.examples.load_energy()

     def test_load_world_bank_health_n_pop(self):
-        examples.load_world_bank_health_n_pop()
+        self.examples.load_world_bank_health_n_pop()

     def test_load_birth_names(self):
-        examples.load_birth_names()
+        self.examples.load_birth_names()

     def test_load_test_users_run(self):
+        from superset.cli import load_test_users_run
+
         load_test_users_run()

     def test_load_unicode_test_data(self):
-        examples.load_unicode_test_data()
+        self.examples.load_unicode_test_data()

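The notable move here is the late import in `setUp`: as the in-diff comment says, `superset.examples` needs an app context when it is imported and used, so the import is deferred until the test harness has pushed one. A generic sketch of the pattern; the module name is hypothetical:

class LateImportExample:
    def setUp(self) -> None:
        # importing inside the method delays any module-level side effects
        # until now, when the test harness has an app context active
        from mypackage import examples  # hypothetical module

        self.examples = examples
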
@@ -14,14 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import unittest
+# isort:skip_file
 from datetime import datetime, timedelta
 from unittest.mock import Mock, patch, PropertyMock

 from flask_babel import gettext as __
 from selenium.common.exceptions import WebDriverException

-from superset import app, db
+from tests.test_app import app
+from superset import db
 from superset.models.core import Dashboard, Slice
 from superset.models.schedules import (
     DashboardEmailSchedule,

@@ -35,11 +36,12 @@ from superset.tasks.schedules import (
     deliver_slice,
     next_schedules,
 )
+from tests.base_tests import SupersetTestCase

 from .utils import read_fixture


-class SchedulesTestCase(unittest.TestCase):
+class SchedulesTestCase(SupersetTestCase):

     RECIPIENTS = "recipient1@superset.com, recipient2@superset.com"
     BCC = "bcc@superset.com"

@@ -47,41 +49,45 @@ class SchedulesTestCase(SupersetTestCase):

     @classmethod
     def setUpClass(cls):
-        cls.common_data = dict(
-            active=True,
-            crontab="* * * * *",
-            recipients=cls.RECIPIENTS,
-            deliver_as_group=True,
-            delivery_type=EmailDeliveryType.inline,
-        )
+        with app.app_context():
+            cls.common_data = dict(
+                active=True,
+                crontab="* * * * *",
+                recipients=cls.RECIPIENTS,
+                deliver_as_group=True,
+                delivery_type=EmailDeliveryType.inline,
+            )

-        # Pick up a random slice and dashboard
-        slce = db.session.query(Slice).all()[0]
-        dashboard = db.session.query(Dashboard).all()[0]
+            # Pick up a random slice and dashboard
+            slce = db.session.query(Slice).all()[0]
+            dashboard = db.session.query(Dashboard).all()[0]

-        dashboard_schedule = DashboardEmailSchedule(**cls.common_data)
-        dashboard_schedule.dashboard_id = dashboard.id
-        dashboard_schedule.user_id = 1
-        db.session.add(dashboard_schedule)
+            dashboard_schedule = DashboardEmailSchedule(**cls.common_data)
+            dashboard_schedule.dashboard_id = dashboard.id
+            dashboard_schedule.user_id = 1
+            db.session.add(dashboard_schedule)

-        slice_schedule = SliceEmailSchedule(**cls.common_data)
-        slice_schedule.slice_id = slce.id
-        slice_schedule.user_id = 1
-        slice_schedule.email_format = SliceEmailReportFormat.data
+            slice_schedule = SliceEmailSchedule(**cls.common_data)
+            slice_schedule.slice_id = slce.id
+            slice_schedule.user_id = 1
+            slice_schedule.email_format = SliceEmailReportFormat.data

-        db.session.add(slice_schedule)
-        db.session.commit()
+            db.session.add(slice_schedule)
+            db.session.commit()

-        cls.slice_schedule = slice_schedule.id
-        cls.dashboard_schedule = dashboard_schedule.id
+            cls.slice_schedule = slice_schedule.id
+            cls.dashboard_schedule = dashboard_schedule.id

     @classmethod
     def tearDownClass(cls):
-        db.session.query(SliceEmailSchedule).filter_by(id=cls.slice_schedule).delete()
-        db.session.query(DashboardEmailSchedule).filter_by(
-            id=cls.dashboard_schedule
-        ).delete()
-        db.session.commit()
+        with app.app_context():
+            db.session.query(SliceEmailSchedule).filter_by(
+                id=cls.slice_schedule
+            ).delete()
+            db.session.query(DashboardEmailSchedule).filter_by(
+                id=cls.dashboard_schedule
+            ).delete()
+            db.session.commit()

     def test_crontab_scheduler(self):
         crontab = "* * * * *"

@@ -310,6 +310,7 @@ class RolePermissionTests(SupersetTestCase):
             ["Superset", "welcome"],
             ["SecurityApi", "login"],
             ["SecurityApi", "refresh"],
+            ["SupersetIndexView", "index"],
         ]
         unsecured_views = []
         for view_class in appbuilder.baseviews:

@@ -60,7 +60,11 @@ class SqlValidatorEndpointTests(SupersetTestCase):
         self.assertIn("no SQL validator is configured", resp["error"])

     @patch("superset.views.core.get_validator_by_name")
-    @patch.dict("superset._feature_flags", PRESTO_TEST_FEATURE_FLAGS, clear=True)
+    @patch.dict(
+        "superset.extensions.feature_flag_manager._feature_flags",
+        PRESTO_TEST_FEATURE_FLAGS,
+        clear=True,
+    )
     def test_validate_sql_endpoint_mocked(self, get_validator_by_name):
         """Assert that, with a mocked validator, annotations make it back out
         from the validate_sql_json endpoint as a list of json dictionaries"""

@@ -87,7 +91,11 @@ class SqlValidatorEndpointTests(SupersetTestCase):
         self.assertIn("expected,", resp[0]["message"])

     @patch("superset.views.core.get_validator_by_name")
-    @patch.dict("superset._feature_flags", PRESTO_TEST_FEATURE_FLAGS, clear=True)
+    @patch.dict(
+        "superset.extensions.feature_flag_manager._feature_flags",
+        PRESTO_TEST_FEATURE_FLAGS,
+        clear=True,
+    )
     def test_validate_sql_endpoint_failure(self, get_validator_by_name):
         """Assert that validate_sql_json errors out when the selected validator
         raises an unexpected exception"""

@@ -17,6 +17,7 @@
 """Unit tests for Sql Lab"""
 import json
 from datetime import datetime, timedelta
+from random import random

 import prison

@@ -294,19 +295,19 @@ class SqlLabTests(SupersetTestCase):
         examples_dbid = get_example_database().id
         payload = {
             "chartType": "dist_bar",
-            "datasourceName": "test_viz_flow_table",
+            "datasourceName": f"test_viz_flow_table_{random()}",
             "schema": "superset",
             "columns": [
                 {
                     "is_date": False,
                     "type": "STRING",
-                    "name": "viz_type",
+                    "name": f"viz_type_{random()}",
                     "is_dim": True,
                 },
                 {
                     "is_date": False,
                     "type": "OBJECT",
-                    "name": "ccount",
+                    "name": f"ccount_{random()}",
                     "is_dim": True,
                     "agg": "sum",
                 },

@@ -421,3 +422,4 @@ class SqlLabTests(SupersetTestCase):
             {"examples", "fake_db_100"},
             {r.get("database_name") for r in self.get_json_resp(url)["result"]},
         )
+        self.delete_fake_db()

@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Here is where we create the app which ends up being shared across all tests. A future
+optimization will be to create a separate app instance for each test class.
+"""
+from superset.app import create_app
+
+app = create_app()

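This tiny module is the linchpin of the test refactor: one shared app built by the factory, imported everywhere as `from tests.test_app import app`. A hedged sketch of what a `create_app()` factory typically does; Superset's real implementation is not part of this diff, and the parameter name below is assumed:

from flask import Flask

def create_app(config_object="superset.config"):  # parameter and default assumed
    app = Flask(__name__)
    app.config.from_object(config_object)
    # extension wiring (db, cache, appbuilder, ...) would happen here,
    # each extension created at module scope and bound via init_app(app)
    return app

app = create_app()
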
@@ -28,6 +28,7 @@ from sqlalchemy.exc import ArgumentError
 from superset import app, db, security_manager
 from superset.exceptions import SupersetException
 from superset.models.core import Database
+from superset.utils.cache_manager import CacheManager
 from superset.utils.core import (
     base_json_conv,
     convert_legacy_filters_into_adhoc,

@@ -45,7 +46,6 @@ from superset.utils.core import (
     parse_human_timedelta,
     parse_js_uri_path_item,
     parse_past_timedelta,
-    setup_cache,
     split,
     TimeRangeEndpoint,
     validate_json,

@@ -53,6 +53,7 @@ from superset.utils.core import (
     zlib_decompress,
 )
 from superset.views.utils import get_time_range_endpoints
+from tests.base_tests import SupersetTestCase


 def mock_parse_human_datetime(s):

@@ -93,7 +94,7 @@ def mock_to_adhoc(filt, expressionType="SIMPLE", clause="where"):
     return result


-class UtilsTestCase(unittest.TestCase):
+class UtilsTestCase(SupersetTestCase):
     def test_json_int_dttm_ser(self):
         dttm = datetime(2020, 1, 1)
         ts = 1577836800000.0

@@ -809,12 +810,12 @@ class UtilsTestCase(SupersetTestCase):
     def test_setup_cache_no_config(self):
         app = Flask(__name__)
         cache_config = None
-        self.assertIsNone(setup_cache(app, cache_config))
+        self.assertIsNone(CacheManager._setup_cache(app, cache_config))

     def test_setup_cache_null_config(self):
         app = Flask(__name__)
         cache_config = {"CACHE_TYPE": "null"}
-        self.assertIsNone(setup_cache(app, cache_config))
+        self.assertIsNone(CacheManager._setup_cache(app, cache_config))

     def test_setup_cache_standard_config(self):
         app = Flask(__name__)

@@ -824,7 +825,7 @@ class UtilsTestCase(SupersetTestCase):
             "CACHE_KEY_PREFIX": "superset_results",
             "CACHE_REDIS_URL": "redis://localhost:6379/0",
         }
-        assert isinstance(setup_cache(app, cache_config), Cache) is True
+        assert isinstance(CacheManager._setup_cache(app, cache_config), Cache) is True

     def test_setup_cache_custom_function(self):
         app = Flask(__name__)

@@ -833,7 +834,9 @@ class UtilsTestCase(SupersetTestCase):
         def init_cache(app):
             return CustomCache(app, {})

-        assert isinstance(setup_cache(app, init_cache), CustomCache) is True
+        assert (
+            isinstance(CacheManager._setup_cache(app, init_cache), CustomCache) is True
+        )

     def test_get_stacktrace(self):
         with app.app_context():

@@ -879,6 +882,8 @@ class UtilsTestCase(SupersetTestCase):
         get_or_create_db("test_db", "sqlite:///changed.db")
         database = db.session.query(Database).filter_by(database_name="test_db").one()
         self.assertEqual(database.sqlalchemy_uri, "sqlite:///changed.db")
+        db.session.delete(database)
+        db.session.commit()

     def test_get_or_create_db_invalid_uri(self):
         with self.assertRaises(ArgumentError):

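The cache tests now target `CacheManager._setup_cache` instead of the removed `utils.core.setup_cache`. The four assertions above imply its contract; here is a hedged reconstruction of that shape (the real code lives in superset/utils/cache_manager.py and may differ):

from flask import Flask
from flask_caching import Cache

class CacheManager:
    @staticmethod
    def _setup_cache(app: Flask, cache_config):
        if cache_config is None:
            return None                          # test_setup_cache_no_config
        if callable(cache_config):
            return cache_config(app)             # test_setup_cache_custom_function
        if cache_config.get("CACHE_TYPE") == "null":
            return None                          # test_setup_cache_null_config
        return Cache(app, config=cache_config)   # test_setup_cache_standard_config
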
tox.ini
@@ -19,7 +19,7 @@ commands =
     {toxinidir}/superset/bin/superset db upgrade
     {toxinidir}/superset/bin/superset init
     nosetests tests/load_examples_test.py
-    nosetests -e load_examples_test {posargs}
+    nosetests -e load_examples_test tests {posargs}
 deps =
     -rrequirements.txt
     -rrequirements-dev.txt