Flask App factory PR #1 (#8418)

* First cut at app factory

* Setting things back to master

* Working with new FLASK_APP

* Still need to refactor Celery

* CLI mostly working

* Working on unit tests

* Moving cli stuff around a bit

* Removing get in config

* Defaulting test config

* Adding flask-testing

* flask-testing casing

* resultsbackend property bug

* Fixing up cli

* Quick fix for KV api

* Working on save slice

* Fixed core_tests

* Fixed utils_tests

* Most tests working - still need to dig into remaining app_context issue in tests

* All tests passing locally - need to update code comments

* Fixing dashboard tests again

* Blacking

* Sorting imports

* linting

* removing envvar mangling

* blacking

* Fixing unit tests

* isorting

* licensing

* fixing mysql tests

* fixing cypress?

* fixing .flaskenv

* fixing test app_ctx

* fixing cypress

* moving manifest processor around

* moving results backend manager around

* Cleaning up __init__ a bit more

* Addressing PR comments

* Addressing PR comments

* Blacking

* Fixes for running celery worker

* Tuning isort

* Blacking
This commit is contained in:
Craig Rueda 2019-11-20 07:47:06 -08:00 committed by Daniel Vaz Gaspar
parent 300c4ecb0f
commit e490414484
38 changed files with 992 additions and 570 deletions

View File

@ -14,5 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
FLASK_APP=superset:app
FLASK_ENV=development
FLASK_APP="superset.app:create_app()"
FLASK_ENV="development"

View File

@ -17,6 +17,7 @@
black==19.3b0
coverage==4.5.3
flask-cors==3.0.7
flask-testing==0.7.1
ipdb==0.12
isort==4.3.21
mypy==0.670

View File

@ -128,4 +128,5 @@ setup(
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
],
tests_require=["flask-testing==0.7.1"],
)

View File

@ -14,229 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
"""Package's main module!"""
import json
import logging
import os
from copy import deepcopy
from typing import Any, Dict
from flask import current_app, Flask
from werkzeug.local import LocalProxy
import wtforms_json
from flask import Flask, redirect
from flask_appbuilder import AppBuilder, IndexView, SQLA
from flask_appbuilder.baseviews import expose
from flask_compress import Compress
from flask_migrate import Migrate
from flask_talisman import Talisman
from flask_wtf.csrf import CSRFProtect
from superset import config
from superset.app import create_app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import (
appbuilder,
cache_manager,
db,
event_logger,
feature_flag_manager,
manifest_processor,
results_backend_manager,
security_manager,
talisman,
)
from superset.security import SupersetSecurityManager
from superset.utils.core import pessimistic_connection_handling, setup_cache
from superset.utils.log import get_event_logger_from_cfg_value
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
wtforms_json.init()
APP_DIR = os.path.dirname(__file__)
CONFIG_MODULE = os.environ.get("SUPERSET_CONFIG", "superset.config")
if not os.path.exists(config.DATA_DIR):
os.makedirs(config.DATA_DIR)
app = Flask(__name__)
app.config.from_object(CONFIG_MODULE) # type: ignore
conf = app.config
#################################################################
# Handling manifest file logic at app start
#################################################################
MANIFEST_FILE = APP_DIR + "/static/assets/dist/manifest.json"
manifest: Dict[Any, Any] = {}
def parse_manifest_json():
global manifest
try:
with open(MANIFEST_FILE, "r") as f:
# the manifest inclues non-entry files
# we only need entries in templates
full_manifest = json.load(f)
manifest = full_manifest.get("entrypoints", {})
except Exception:
pass
def get_js_manifest_files(filename):
if app.debug:
parse_manifest_json()
entry_files = manifest.get(filename, {})
return entry_files.get("js", [])
def get_css_manifest_files(filename):
if app.debug:
parse_manifest_json()
entry_files = manifest.get(filename, {})
return entry_files.get("css", [])
def get_unloaded_chunks(files, loaded_chunks):
filtered_files = [f for f in files if f not in loaded_chunks]
for f in filtered_files:
loaded_chunks.add(f)
return filtered_files
parse_manifest_json()
@app.context_processor
def get_manifest():
return dict(
loaded_chunks=set(),
get_unloaded_chunks=get_unloaded_chunks,
js_manifest=get_js_manifest_files,
css_manifest=get_css_manifest_files,
)
#################################################################
for bp in conf["BLUEPRINTS"]:
try:
print("Registering blueprint: '{}'".format(bp.name))
app.register_blueprint(bp)
except Exception as e:
print("blueprint registration failed")
logging.exception(e)
if conf.get("SILENCE_FAB"):
logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
db = SQLA(app)
if conf.get("WTF_CSRF_ENABLED"):
csrf = CSRFProtect(app)
csrf_exempt_list = conf.get("WTF_CSRF_EXEMPT_LIST", [])
for ex in csrf_exempt_list:
csrf.exempt(ex)
pessimistic_connection_handling(db.engine)
cache = setup_cache(app, conf.get("CACHE_CONFIG"))
tables_cache = setup_cache(app, conf.get("TABLE_NAMES_CACHE_CONFIG"))
migrate = Migrate(app, db, directory=APP_DIR + "/migrations")
app.config["LOGGING_CONFIGURATOR"].configure_logging(app.config, app.debug)
if app.config["ENABLE_CORS"]:
from flask_cors import CORS
CORS(app, **app.config["CORS_OPTIONS"])
if app.config["ENABLE_PROXY_FIX"]:
from werkzeug.middleware.proxy_fix import ProxyFix
app.wsgi_app = ProxyFix( # type: ignore
app.wsgi_app, **app.config["PROXY_FIX_CONFIG"]
)
if app.config["ENABLE_CHUNK_ENCODING"]:
class ChunkedEncodingFix(object):
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == u"chunked":
environ["wsgi.input_terminated"] = True
return self.app(environ, start_response)
app.wsgi_app = ChunkedEncodingFix(app.wsgi_app) # type: ignore
if app.config["UPLOAD_FOLDER"]:
try:
os.makedirs(app.config["UPLOAD_FOLDER"])
except OSError:
pass
for middleware in app.config["ADDITIONAL_MIDDLEWARE"]:
app.wsgi_app = middleware(app.wsgi_app) # type: ignore
class MyIndexView(IndexView):
@expose("/")
def index(self):
return redirect("/superset/welcome")
custom_sm = app.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager
if not issubclass(custom_sm, SupersetSecurityManager):
raise Exception(
"""Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
not FAB's security manager.
See [4565] in UPDATING.md"""
)
with app.app_context():
appbuilder = AppBuilder(
app,
db.session,
base_template="superset/base.html",
indexview=MyIndexView,
security_manager_class=custom_sm,
update_perms=False, # Run `superset init` to update FAB's perms
)
security_manager = appbuilder.sm
results_backend = app.config["RESULTS_BACKEND"]
results_backend_use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
# Merge user defined feature flags with default feature flags
_feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
_feature_flags.update(app.config["FEATURE_FLAGS"])
# Event Logger
event_logger = get_event_logger_from_cfg_value(app.config["EVENT_LOGGER"])
def get_feature_flags():
GET_FEATURE_FLAGS_FUNC = app.config["GET_FEATURE_FLAGS_FUNC"]
if GET_FEATURE_FLAGS_FUNC:
return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
return _feature_flags
def is_feature_enabled(feature):
"""Utility function for checking whether a feature is turned on"""
return get_feature_flags().get(feature)
# Flask-Compress
if conf.get("ENABLE_FLASK_COMPRESS"):
Compress(app)
talisman = Talisman()
if app.config["TALISMAN_ENABLED"]:
talisman.init_app(app, **app.config["TALISMAN_CONFIG"])
# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = app.config["FLASK_APP_MUTATOR"]
if flask_app_mutator:
flask_app_mutator(app)
from superset import views # noqa isort:skip
# Registering sources
module_datasource_map = app.config["DEFAULT_MODULE_DS_MAP"]
module_datasource_map.update(app.config["ADDITIONAL_MODULE_DS_MAP"])
ConnectorRegistry.register_sources(module_datasource_map)
# All of the fields located here should be considered legacy. The correct way
# to declare "global" dependencies is to define it in extensions.py,
# then initialize it in app.create_app(). These fields will be removed
# in subsequent PRs as things are migrated towards the factory pattern
app: Flask = current_app
cache = LocalProxy(lambda: cache_manager.cache)
conf = LocalProxy(lambda: current_app.config)
get_feature_flags = feature_flag_manager.get_feature_flags
get_css_manifest_files = manifest_processor.get_css_manifest_files
is_feature_enabled = feature_flag_manager.is_feature_enabled
results_backend = LocalProxy(lambda: results_backend_manager.results_backend)
results_backend_use_msgpack = LocalProxy(
lambda: results_backend_manager.should_use_msgpack
)
tables_cache = LocalProxy(lambda: cache_manager.tables_cache)

260
superset/app.py Normal file
View File

@ -0,0 +1,260 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import wtforms_json
from flask import Flask, redirect
from flask_appbuilder import expose, IndexView
from flask_compress import Compress
from flask_wtf import CSRFProtect
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import (
_event_logger,
APP_DIR,
appbuilder,
cache_manager,
celery_app,
db,
feature_flag_manager,
manifest_processor,
migrate,
results_backend_manager,
talisman,
)
from superset.security import SupersetSecurityManager
from superset.utils.core import pessimistic_connection_handling
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
logger = logging.getLogger(__name__)
def create_app():
app = Flask(__name__)
try:
# Allow user to override our config completely
config_module = os.environ.get("SUPERSET_CONFIG", "superset.config")
app.config.from_object(config_module)
app_initializer = app.config.get("APP_INITIALIZER", SupersetAppInitializer)(app)
app_initializer.init_app()
return app
# Make sure that bootstrap errors ALWAYS get logged
except Exception as ex:
logger.exception("Failed to create app")
raise ex
class SupersetIndexView(IndexView):
@expose("/")
def index(self):
return redirect("/superset/welcome")
class SupersetAppInitializer:
def __init__(self, app: Flask) -> None:
super().__init__()
self.flask_app = app
self.config = app.config
self.manifest: dict = {}
def pre_init(self) -> None:
"""
Called after all other init tasks are complete
"""
wtforms_json.init()
if not os.path.exists(self.config["DATA_DIR"]):
os.makedirs(self.config["DATA_DIR"])
def post_init(self) -> None:
"""
Called before any other init tasks
"""
pass
def configure_celery(self) -> None:
celery_app.config_from_object(self.config["CELERY_CONFIG"])
celery_app.set_default()
@staticmethod
def init_views() -> None:
# TODO - This should iterate over all views and register them with FAB...
from superset import views # noqa pylint: disable=unused-variable
def init_app_in_ctx(self) -> None:
"""
Runs init logic in the context of the app
"""
self.configure_feature_flags()
self.configure_fab()
self.configure_data_sources()
# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = self.config["FLASK_APP_MUTATOR"]
if flask_app_mutator:
flask_app_mutator(self.flask_app)
self.init_views()
def init_app(self) -> None:
"""
Main entry point which will delegate to other methods in
order to fully init the app
"""
self.pre_init()
self.setup_db()
self.configure_celery()
self.setup_event_logger()
self.setup_bundle_manifest()
self.register_blueprints()
self.configure_wtf()
self.configure_logging()
self.configure_middlewares()
self.configure_cache()
with self.flask_app.app_context():
self.init_app_in_ctx()
self.post_init()
def setup_event_logger(self):
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
)
def configure_data_sources(self):
# Registering sources
module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
ConnectorRegistry.register_sources(module_datasource_map)
def configure_cache(self):
cache_manager.init_app(self.flask_app)
results_backend_manager.init_app(self.flask_app)
def configure_feature_flags(self):
feature_flag_manager.init_app(self.flask_app)
def configure_fab(self):
if self.config["SILENCE_FAB"]:
logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
custom_sm = self.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager
if not issubclass(custom_sm, SupersetSecurityManager):
raise Exception(
"""Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
not FAB's security manager.
See [4565] in UPDATING.md"""
)
appbuilder.indexview = SupersetIndexView
appbuilder.base_template = "superset/base.html"
appbuilder.security_manager_class = custom_sm
appbuilder.update_perms = False
appbuilder.init_app(self.flask_app, db.session)
def configure_middlewares(self):
if self.config["ENABLE_CORS"]:
from flask_cors import CORS
CORS(self.flask_app, **self.config["CORS_OPTIONS"])
if self.config["ENABLE_PROXY_FIX"]:
from werkzeug.middleware.proxy_fix import ProxyFix
self.flask_app.wsgi_app = ProxyFix(
self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"]
)
if self.config["ENABLE_CHUNK_ENCODING"]:
class ChunkedEncodingFix(object): # pylint: disable=too-few-public-methods
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked":
environ["wsgi.input_terminated"] = True
return self.app(environ, start_response)
self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app)
if self.config["UPLOAD_FOLDER"]:
try:
os.makedirs(self.config["UPLOAD_FOLDER"])
except OSError:
pass
for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app)
# Flask-Compress
if self.config["ENABLE_FLASK_COMPRESS"]:
Compress(self.flask_app)
if self.config["TALISMAN_ENABLED"]:
talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"])
def configure_logging(self):
self.config["LOGGING_CONFIGURATOR"].configure_logging(
self.config, self.flask_app.debug
)
def setup_db(self):
db.init_app(self.flask_app)
with self.flask_app.app_context():
pessimistic_connection_handling(db.engine)
migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations")
def configure_wtf(self):
if self.config["WTF_CSRF_ENABLED"]:
csrf = CSRFProtect(self.flask_app)
csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"]
for ex in csrf_exempt_list:
csrf.exempt(ex)
def register_blueprints(self):
for bp in self.config["BLUEPRINTS"]:
try:
logger.info(f"Registering blueprint: '{bp.name}'")
self.flask_app.register_blueprint(bp)
except Exception: # pylint: disable=broad-except
logger.exception("blueprint registration failed")
def setup_bundle_manifest(self):
manifest_processor.init_app(self.flask_app)

View File

@ -15,17 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import click
from flask.cli import FlaskGroup
from superset.cli import create_app
@click.group(cls=FlaskGroup, create_app=create_app)
def cli():
"""This is a management script for the Superset application."""
pass
from superset.cli import superset
if __name__ == '__main__':
cli()
superset()

View File

@ -25,27 +25,28 @@ import click
import yaml
from colorama import Fore, Style
from flask import g
from flask.cli import FlaskGroup, with_appcontext
from flask_appbuilder import Model
from pathlib2 import Path
from superset import app, appbuilder, db, examples, security_manager
from superset.common.tags import add_favorites, add_owners, add_types
from superset.utils import core as utils, dashboard_import_export, dict_import_export
config = app.config
celery_app = utils.get_celery_app(config)
from superset import app, appbuilder, security_manager
from superset.app import create_app
from superset.extensions import celery_app, db
from superset.utils import core as utils
def create_app(script_info=None):
return app
@click.group(cls=FlaskGroup, create_app=create_app)
@with_appcontext
def superset():
"""This is a management script for the Superset application."""
@app.shell_context_processor
def make_shell_context():
return dict(app=app, db=db)
@app.shell_context_processor
def make_shell_context():
return dict(app=app, db=db)
@app.cli.command()
@superset.command()
@with_appcontext
def init():
"""Inits the Superset application"""
utils.get_example_database()
@ -53,7 +54,8 @@ def init():
security_manager.sync_role_definitions()
@app.cli.command()
@superset.command()
@with_appcontext
@click.option("--verbose", "-v", is_flag=True, help="Show extra information")
def version(verbose):
"""Prints the current version number"""
@ -62,7 +64,7 @@ def version(verbose):
Fore.YELLOW
+ "Superset "
+ Fore.CYAN
+ "{version}".format(version=config["VERSION_STRING"])
+ "{version}".format(version=app.config["VERSION_STRING"])
)
print(Fore.BLUE + "-=" * 15)
if verbose:
@ -77,6 +79,8 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
examples_db = utils.get_example_database()
print(f"Loading examples metadata and related data into {examples_db}")
from superset import examples
examples.load_css_templates()
print("Loading energy related dataset")
@ -129,7 +133,8 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
examples.load_tabbed_dashboard(only_metadata)
@app.cli.command()
@with_appcontext
@superset.command()
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
@click.option(
"--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data"
@ -142,7 +147,8 @@ def load_examples(load_test_data, only_metadata=False, force=False):
load_examples_run(load_test_data, only_metadata, force)
@app.cli.command()
@with_appcontext
@superset.command()
@click.option("--database_name", "-d", help="Database name to change")
@click.option("--uri", "-u", help="Database URI to change")
def set_database_uri(database_name, uri):
@ -150,7 +156,8 @@ def set_database_uri(database_name, uri):
utils.get_or_create_db(database_name, uri)
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--datasource",
"-d",
@ -180,7 +187,8 @@ def refresh_druid(datasource, merge):
session.commit()
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--path",
"-p",
@ -202,6 +210,8 @@ def refresh_druid(datasource, merge):
)
def import_dashboards(path, recursive, username):
"""Import dashboards from JSON"""
from superset.utils import dashboard_import_export
p = Path(path)
files = []
if p.is_file():
@ -222,7 +232,8 @@ def import_dashboards(path, recursive, username):
logging.error(e)
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--dashboard-file", "-f", default=None, help="Specify the the file to export to"
)
@ -231,6 +242,8 @@ def import_dashboards(path, recursive, username):
)
def export_dashboards(print_stdout, dashboard_file):
"""Export dashboards to JSON"""
from superset.utils import dashboard_import_export
data = dashboard_import_export.export_dashboards(db.session)
if print_stdout or not dashboard_file:
print(data)
@ -240,7 +253,8 @@ def export_dashboards(print_stdout, dashboard_file):
data_stream.write(data)
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--path",
"-p",
@ -265,6 +279,8 @@ def export_dashboards(print_stdout, dashboard_file):
)
def import_datasources(path, sync, recursive):
"""Import datasources from YAML"""
from superset.utils import dict_import_export
sync_array = sync.split(",")
p = Path(path)
files = []
@ -288,7 +304,8 @@ def import_datasources(path, sync, recursive):
logging.error(e)
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--datasource-file", "-f", default=None, help="Specify the the file to export to"
)
@ -313,6 +330,8 @@ def export_datasources(
print_stdout, datasource_file, back_references, include_defaults
):
"""Export datasources to YAML"""
from superset.utils import dict_import_export
data = dict_import_export.export_to_dict(
session=db.session,
recursive=True,
@ -327,7 +346,8 @@ def export_datasources(
yaml.safe_dump(data, data_stream, default_flow_style=False)
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--back-references",
"-b",
@ -337,11 +357,14 @@ def export_datasources(
)
def export_datasource_schema(back_references):
"""Export datasource YAML schema to stdout"""
from superset.utils import dict_import_export
data = dict_import_export.export_schema_to_dict(back_references=back_references)
yaml.safe_dump(data, stdout, default_flow_style=False)
@app.cli.command()
@superset.command()
@with_appcontext
def update_datasources_cache():
"""Refresh sqllab datasources cache"""
from superset.models.core import Database
@ -360,7 +383,8 @@ def update_datasources_cache():
print("{}".format(str(e)))
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"--workers", "-w", type=int, help="Number of celery server workers to fire up"
)
@ -372,14 +396,17 @@ def worker(workers):
)
if workers:
celery_app.conf.update(CELERYD_CONCURRENCY=workers)
elif config["SUPERSET_CELERY_WORKERS"]:
celery_app.conf.update(CELERYD_CONCURRENCY=config["SUPERSET_CELERY_WORKERS"])
elif app.config["SUPERSET_CELERY_WORKERS"]:
celery_app.conf.update(
CELERYD_CONCURRENCY=app.config["SUPERSET_CELERY_WORKERS"]
)
worker = celery_app.Worker(optimization="fair")
worker.start()
@app.cli.command()
@superset.command()
@with_appcontext
@click.option(
"-p", "--port", default="5555", help="Port on which to start the Flower process"
)
@ -409,7 +436,8 @@ def flower(port, address):
Popen(cmd, shell=True).wait()
@app.cli.command()
@superset.command()
@with_appcontext
def load_test_users():
"""
Loads admin, alpha, and gamma user for testing purposes
@ -426,7 +454,7 @@ def load_test_users_run():
Syncs permissions for those users/roles
"""
if config["TESTING"]:
if app.config["TESTING"]:
sm = security_manager
@ -463,11 +491,15 @@ def load_test_users_run():
sm.get_session.commit()
@app.cli.command()
@superset.command()
@with_appcontext
def sync_tags():
"""Rebuilds special tags (owner, type, favorited by)."""
# pylint: disable=no-member
metadata = Model.metadata
from superset.common.tags import add_favorites, add_owners, add_types
add_types(db.engine, metadata)
add_owners(db.engine, metadata)
add_favorites(db.engine, metadata)

113
superset/extensions.py Normal file
View File

@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import celery
from flask_appbuilder import AppBuilder, SQLA
from flask_migrate import Migrate
from flask_talisman import Talisman
from werkzeug.local import LocalProxy
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager
class ResultsBackendManager:
def __init__(self) -> None:
super().__init__()
self._results_backend = None
self._use_msgpack = False
def init_app(self, app):
self._results_backend = app.config.get("RESULTS_BACKEND")
self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")
@property
def results_backend(self):
return self._results_backend
@property
def should_use_msgpack(self):
return self._use_msgpack
class UIManifestProcessor:
def __init__(self, app_dir: str) -> None:
super().__init__()
self.app = None
self.manifest: dict = {}
self.manifest_file = f"{app_dir}/static/assets/dist/manifest.json"
def init_app(self, app):
self.app = app
# Preload the cache
self.parse_manifest_json()
@app.context_processor
def get_manifest(): # pylint: disable=unused-variable
return dict(
loaded_chunks=set(),
get_unloaded_chunks=self.get_unloaded_chunks,
js_manifest=self.get_js_manifest_files,
css_manifest=self.get_css_manifest_files,
)
def parse_manifest_json(self):
try:
with open(self.manifest_file, "r") as f:
# the manifest includes non-entry files
# we only need entries in templates
full_manifest = json.load(f)
self.manifest = full_manifest.get("entrypoints", {})
except Exception: # pylint: disable=broad-except
pass
def get_js_manifest_files(self, filename):
if self.app.debug:
self.parse_manifest_json()
entry_files = self.manifest.get(filename, {})
return entry_files.get("js", [])
def get_css_manifest_files(self, filename):
if self.app.debug:
self.parse_manifest_json()
entry_files = self.manifest.get(filename, {})
return entry_files.get("css", [])
@staticmethod
def get_unloaded_chunks(files, loaded_chunks):
filtered_files = [f for f in files if f not in loaded_chunks]
for f in filtered_files:
loaded_chunks.add(f)
return filtered_files
APP_DIR = os.path.dirname(__file__)
appbuilder = AppBuilder(update_perms=False)
cache_manager = CacheManager()
celery_app = celery.Celery()
db = SQLA()
_event_logger: dict = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
manifest_processor = UIManifestProcessor(APP_DIR)
migrate = Migrate()
results_backend_manager = ResultsBackendManager()
security_manager = LocalProxy(lambda: appbuilder.sm)
talisman = Talisman()

View File

@ -19,10 +19,6 @@
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from wtforms import Field
from superset import app
config = app.config
class CommaSeparatedListField(Field):
widget = BS3TextFieldWidget()

View File

@ -42,9 +42,9 @@ from superset import (
)
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import json_iso_dttm_ser, QueryStatus, sources, zlib_compress
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing

View File

@ -15,4 +15,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import cache, schedules

View File

@ -25,9 +25,9 @@ from celery.utils.log import get_task_logger
from sqlalchemy import and_, func
from superset import app, db
from superset.extensions import celery_app
from superset.models.core import Dashboard, Log, Slice
from superset.models.tags import Tag, TaggedObject
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import parse_human_datetime
logger = get_task_logger(__name__)

View File

@ -16,12 +16,20 @@
# under the License.
# pylint: disable=C,R,W
"""Utility functions used across Superset"""
"""
This is the main entrypoint used by Celery workers. As such,
it needs to call create_app() in order to initialize things properly
"""
# Superset framework imports
from superset import app
from superset.utils.core import get_celery_app
from superset import create_app
from superset.extensions import celery_app
# Globals
config = app.config
app = get_celery_app(config)
# Init the Flask app / configure everything
create_app()
# Need to import late, as the celery_app will have been setup by "create_app()"
from . import cache, schedules # isort:skip
# Export the celery app globally for Celery (as run on the cmd line) to find
app = celery_app

View File

@ -39,13 +39,13 @@ from werkzeug.http import parse_cookie
# Superset framework imports
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.schedules import (
EmailDeliveryType,
get_scheduler_model,
ScheduleType,
SliceEmailReportFormat,
)
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import get_email_address_list, send_email_smtp
# Globals

View File

@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from flask import request
from typing import Optional
from superset import tables_cache
from flask import Flask, request
from flask_caching import Cache
from superset.extensions import cache_manager
def view_cache_key(*unused_args, **unused_kwargs) -> str:
@ -43,7 +46,7 @@ def memoized_func(key=view_cache_key, attribute_in_key=None):
"""
def wrap(f):
if tables_cache:
if cache_manager.tables_cache:
def wrapped_f(self, *args, **kwargs):
if not kwargs.get("cache", True):
@ -55,11 +58,13 @@ def memoized_func(key=view_cache_key, attribute_in_key=None):
)
else:
cache_key = key(*args, **kwargs)
o = tables_cache.get(cache_key)
o = cache_manager.tables_cache.get(cache_key)
if not kwargs.get("force") and o is not None:
return o
o = f(self, *args, **kwargs)
tables_cache.set(cache_key, o, timeout=kwargs.get("cache_timeout"))
cache_manager.tables_cache.set(
cache_key, o, timeout=kwargs.get("cache_timeout")
)
return o
else:

View File

@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional
from flask import Flask
from flask_caching import Cache
class CacheManager:
def __init__(self) -> None:
super().__init__()
self._tables_cache = None
self._cache = None
def init_app(self, app):
self._cache = self._setup_cache(app, app.config.get("CACHE_CONFIG"))
self._tables_cache = self._setup_cache(
app, app.config.get("TABLE_NAMES_CACHE_CONFIG")
)
@staticmethod
def _setup_cache(app: Flask, cache_config) -> Optional[Cache]:
"""Setup the flask-cache on a flask app"""
if cache_config:
if isinstance(cache_config, dict):
if cache_config.get("CACHE_TYPE") != "null":
return Cache(app, config=cache_config)
else:
# Accepts a custom cache initialization function,
# returning an object compatible with Flask-Caching API
return cache_config(app)
return None
@property
def tables_cache(self):
return self._tables_cache
@property
def cache(self):
return self._cache

View File

@ -791,20 +791,6 @@ def choicify(values):
return [(v, v) for v in values]
def setup_cache(app: Flask, cache_config) -> Optional[Cache]:
"""Setup the flask-cache on a flask app"""
if cache_config:
if isinstance(cache_config, dict):
if cache_config["CACHE_TYPE"] != "null":
return Cache(app, config=cache_config)
else:
# Accepts a custom cache initialization function,
# returning an object compatible with Flask-Caching API
return cache_config(app)
return None
def zlib_compress(data):
"""
Compress things in a py2/3 safe fashion
@ -832,19 +818,6 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
return decompressed.decode("utf-8") if decode else decompressed
_celery_app = None
def get_celery_app(config):
global _celery_app
if _celery_app:
return _celery_app
_celery_app = celery.Celery()
_celery_app.config_from_object(config["CELERY_CONFIG"])
_celery_app.set_default()
return _celery_app
def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
result = {
"clause": clause.upper(),

View File

@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
class FeatureFlagManager:
def __init__(self) -> None:
super().__init__()
self._get_feature_flags_func = None
self._feature_flags = None
def init_app(self, app):
self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC")
self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
def get_feature_flags(self):
if self._get_feature_flags_func:
return self._get_feature_flags_func(deepcopy(self._feature_flags))
return self._feature_flags
def is_feature_enabled(self, feature):
"""Utility function for checking whether a feature is turned on"""
return self.get_feature_flags().get(feature)

View File

@ -686,7 +686,9 @@ class KV(BaseSupersetView):
def get_value(self, key_id):
kv = None
try:
kv = db.session.query(models.KeyValue).filter_by(id=key_id).one()
kv = db.session.query(models.KeyValue).filter_by(id=key_id).scalar()
if not kv:
return Response(status=404, content_type="text/plain")
except Exception as e:
return json_error_response(e)
return Response(kv.value, status=200, content_type="text/plain")
@ -736,6 +738,8 @@ appbuilder.add_view_no_menu(R)
class Superset(BaseSupersetView):
"""The base views for Superset!"""
logger = logging.getLogger(__name__)
@has_access_api
@expose("/datasources/")
def datasources(self):
@ -2059,6 +2063,7 @@ class Superset(BaseSupersetView):
)
obj.get_json()
except Exception as e:
self.logger.exception("Failed to warm up cache")
return json_error_response(utils.error_msg_from_exception(e))
return json_success(
json.dumps(

View File

@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from unittest import mock
from superset import app, db, security_manager
from tests.test_app import app # isort:skip
from superset import db, security_manager
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.druid.models import DruidDatasource
from superset.connectors.sqla.models import SqlaTable
@ -94,22 +96,24 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
class RequestAccessTests(SupersetTestCase):
@classmethod
def setUpClass(cls):
security_manager.add_role("override_me")
security_manager.add_role(TEST_ROLE_1)
security_manager.add_role(TEST_ROLE_2)
security_manager.add_role(DB_ACCESS_ROLE)
security_manager.add_role(SCHEMA_ACCESS_ROLE)
db.session.commit()
with app.app_context():
security_manager.add_role("override_me")
security_manager.add_role(TEST_ROLE_1)
security_manager.add_role(TEST_ROLE_2)
security_manager.add_role(DB_ACCESS_ROLE)
security_manager.add_role(SCHEMA_ACCESS_ROLE)
db.session.commit()
@classmethod
def tearDownClass(cls):
override_me = security_manager.find_role("override_me")
db.session.delete(override_me)
db.session.delete(security_manager.find_role(TEST_ROLE_1))
db.session.delete(security_manager.find_role(TEST_ROLE_2))
db.session.delete(security_manager.find_role(DB_ACCESS_ROLE))
db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
db.session.commit()
with app.app_context():
override_me = security_manager.find_role("override_me")
db.session.delete(override_me)
db.session.delete(security_manager.find_role(TEST_ROLE_1))
db.session.delete(security_manager.find_role(TEST_ROLE_2))
db.session.delete(security_manager.find_role(DB_ACCESS_ROLE))
db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
db.session.commit()
def setUp(self):
self.login("admin")

View File

@ -14,52 +14,57 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import imp
import json
import unittest
from unittest.mock import Mock, patch
from unittest.mock import Mock
import pandas as pd
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from superset import app, db, is_feature_enabled, security_manager
from tests.test_app import app # isort:skip
from superset import db, security_manager
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.core import Database
from superset.utils.core import get_example_database
BASE_DIR = app.config["BASE_DIR"]
FAKE_DB_NAME = "fake_db_100"
class SupersetTestCase(unittest.TestCase):
class SupersetTestCase(TestCase):
def __init__(self, *args, **kwargs):
super(SupersetTestCase, self).__init__(*args, **kwargs)
self.client = app.test_client()
self.maxDiff = None
def create_app(self):
return app
@classmethod
def create_druid_test_objects(cls):
# create druid cluster and druid datasources
session = db.session
cluster = (
session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
)
if not cluster:
cluster = DruidCluster(cluster_name="druid_test")
session.add(cluster)
session.commit()
with app.app_context():
session = db.session
cluster = (
session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
)
if not cluster:
cluster = DruidCluster(cluster_name="druid_test")
session.add(cluster)
session.commit()
druid_datasource1 = DruidDatasource(
datasource_name="druid_ds_1", cluster_name="druid_test"
)
session.add(druid_datasource1)
druid_datasource2 = DruidDatasource(
datasource_name="druid_ds_2", cluster_name="druid_test"
)
session.add(druid_datasource2)
session.commit()
druid_datasource1 = DruidDatasource(
datasource_name="druid_ds_1", cluster_name="druid_test"
)
session.add(druid_datasource1)
druid_datasource2 = DruidDatasource(
datasource_name="druid_ds_2", cluster_name="druid_test"
)
session.add(druid_datasource2)
session.commit()
def get_table(self, table_id):
return db.session.query(SqlaTable).filter_by(id=table_id).one()
@ -210,7 +215,7 @@ class SupersetTestCase(unittest.TestCase):
def create_fake_db(self):
self.login(username="admin")
database_name = "fake_db_100"
database_name = FAKE_DB_NAME
db_id = 100
extra = """{
"schemas_allowed_for_csv_upload":
@ -225,6 +230,15 @@ class SupersetTestCase(unittest.TestCase):
extra=extra,
)
def delete_fake_db(self):
database = (
db.session.query(Database)
.filter(Database.database_name == FAKE_DB_NAME)
.scalar()
)
if database:
db.session.delete(database)
def validate_sql(
self,
sql,
@ -246,18 +260,6 @@ class SupersetTestCase(unittest.TestCase):
raise Exception("validate_sql failed")
return resp
@patch.dict("superset._feature_flags", {"FOO": True}, clear=True)
def test_existing_feature_flags(self):
self.assertTrue(is_feature_enabled("FOO"))
@patch.dict("superset._feature_flags", {}, clear=True)
def test_nonexistent_feature_flags(self):
self.assertFalse(is_feature_enabled("FOO"))
def test_feature_flags(self):
self.assertEqual(is_feature_enabled("foo"), "bar")
self.assertEqual(is_feature_enabled("super"), "set")
def get_dash_by_slug(self, dash_slug):
sesh = db.session()
return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first()

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset Celery worker"""
import datetime
import json
@ -22,7 +23,8 @@ import time
import unittest
import unittest.mock as mock
from superset import app, db, sql_lab
from tests.test_app import app # isort:skip
from superset import db, sql_lab
from superset.dataframe import SupersetDataFrame
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.helpers import QueryStatus
@ -32,20 +34,9 @@ from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
BASE_DIR = app.config["BASE_DIR"]
CELERY_SLEEP_TIME = 5
class CeleryConfig(object):
BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL
CELERY_IMPORTS = ("superset.sql_lab",)
CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
CONCURRENCY = 1
app.config["CELERY_CONFIG"] = CeleryConfig
class UtilityFunctionTests(SupersetTestCase):
# TODO(bkyryliuk): support more cases in CTA function.
@ -79,10 +70,6 @@ class UtilityFunctionTests(SupersetTestCase):
class CeleryTestCase(SupersetTestCase):
def __init__(self, *args, **kwargs):
super(CeleryTestCase, self).__init__(*args, **kwargs)
self.client = app.test_client()
def get_query_by_name(self, sql):
session = db.session
query = session.query(Query).filter_by(sql=sql).first()
@ -97,11 +84,22 @@ class CeleryTestCase(SupersetTestCase):
@classmethod
def setUpClass(cls):
db.session.query(Query).delete()
db.session.commit()
with app.app_context():
worker_command = BASE_DIR + "/bin/superset worker -w 2"
subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE)
class CeleryConfig(object):
BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL
CELERY_IMPORTS = ("superset.sql_lab",)
CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
CONCURRENCY = 1
app.config["CELERY_CONFIG"] = CeleryConfig
db.session.query(Query).delete()
db.session.commit()
base_dir = app.config["BASE_DIR"]
worker_command = base_dir + "/bin/superset worker -w 2"
subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE)
@classmethod
def tearDownClass(cls):
@ -190,6 +188,7 @@ class CeleryTestCase(SupersetTestCase):
result = self.run_sql(
db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
)
db.session.close()
assert result["query"]["state"] in (
QueryStatus.PENDING,
QueryStatus.RUNNING,
@ -224,6 +223,7 @@ class CeleryTestCase(SupersetTestCase):
result = self.run_sql(
db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True
)
db.session.close()
assert result["query"]["state"] in (
QueryStatus.PENDING,
QueryStatus.RUNNING,

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import cgi
import csv
@ -33,7 +34,8 @@ import pandas as pd
import psycopg2
import sqlalchemy as sqla
from superset import app, dataframe, db, jinja_context, security_manager, sql_lab
from tests.test_app import app
from superset import dataframe, db, jinja_context, security_manager, sql_lab
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
@ -51,16 +53,13 @@ class CoreTests(SupersetTestCase):
def __init__(self, *args, **kwargs):
super(CoreTests, self).__init__(*args, **kwargs)
@classmethod
def setUpClass(cls):
cls.table_ids = {
tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all())
}
def setUp(self):
db.session.query(Query).delete()
db.session.query(models.DatasourceAccessRequest).delete()
db.session.query(models.Log).delete()
self.table_ids = {
tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all())
}
def tearDown(self):
db.session.query(Query).delete()
@ -196,12 +195,11 @@ class CoreTests(SupersetTestCase):
def test_save_slice(self):
self.login(username="admin")
slice_name = "Energy Sankey"
slice_name = f"Energy Sankey"
slice_id = self.get_slice(slice_name, db.session).id
db.session.commit()
copy_name = "Test Sankey Save"
copy_name = f"Test Sankey Save_{random.random()}"
tbl_id = self.table_ids.get("energy_usage")
new_slice_name = "Test Sankey Overwirte"
new_slice_name = f"Test Sankey Overwrite_{random.random()}"
url = (
"/superset/explore/table/{}/?slice_name={}&"
@ -216,13 +214,17 @@ class CoreTests(SupersetTestCase):
"slice_id": slice_id,
}
# Changing name and save as a new slice
self.get_resp(
resp = self.client.post(
url.format(tbl_id, copy_name, "saveas"),
{"form_data": json.dumps(form_data)},
data={"form_data": json.dumps(form_data)},
)
slices = db.session.query(models.Slice).filter_by(slice_name=copy_name).all()
assert len(slices) == 1
new_slice_id = slices[0].id
db.session.expunge_all()
new_slice_id = resp.json["form_data"]["slice_id"]
slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one()
self.assertEqual(slc.slice_name, copy_name)
form_data.pop("slice_id") # We don't save the slice id when saving as
self.assertEqual(slc.viz.form_data, form_data)
form_data = {
"viz_type": "sankey",
@ -233,14 +235,18 @@ class CoreTests(SupersetTestCase):
"time_range": "now",
}
# Setting the name back to its original name by overwriting new slice
self.get_resp(
self.client.post(
url.format(tbl_id, new_slice_name, "overwrite"),
{"form_data": json.dumps(form_data)},
data={"form_data": json.dumps(form_data)},
)
slc = db.session.query(models.Slice).filter_by(id=new_slice_id).first()
assert slc.slice_name == new_slice_name
assert slc.viz.form_data == form_data
db.session.expunge_all()
slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one()
self.assertEqual(slc.slice_name, new_slice_name)
self.assertEqual(slc.viz.form_data, form_data)
# Cleanup
db.session.delete(slc)
db.session.commit()
def test_filter_endpoint(self):
self.login(username="admin")
@ -406,10 +412,16 @@ class CoreTests(SupersetTestCase):
database = utils.get_example_database()
self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
# Need to clean up after ourselves
database.impersonate_user = False
database.allow_dml = False
database.allow_run_async = False
db.session.commit()
def test_warm_up_cache(self):
slc = self.get_slice("Girls", db.session)
data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
assert data == [{"slice_id": slc.id, "slice_name": slc.slice_name}]
self.assertEqual(data, [{"slice_id": slc.id, "slice_name": slc.slice_name}])
data = self.get_json_resp(
"/superset/warm_up_cache?table_name=energy_usage&db_name=main"
@ -430,13 +442,10 @@ class CoreTests(SupersetTestCase):
assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8"))
def test_kv(self):
self.logout()
self.login(username="admin")
try:
resp = self.client.post("/kv/store/", data=dict())
except Exception:
self.assertRaises(TypeError)
resp = self.client.get("/kv/10001/")
self.assertEqual(404, resp.status_code)
value = json.dumps({"data": "this is a test"})
resp = self.client.post("/kv/store/", data=dict(data=value))
@ -449,11 +458,6 @@ class CoreTests(SupersetTestCase):
self.assertEqual(resp.status_code, 200)
self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8")))
try:
resp = self.client.get("/kv/10001/")
except Exception:
self.assertRaises(TypeError)
def test_gamma(self):
self.login(username="gamma")
assert "Charts" in self.get_resp("/chart/list/")
@ -808,6 +812,7 @@ class CoreTests(SupersetTestCase):
)
)
assert data == ["this_schema_is_allowed_too"]
self.delete_fake_db()
def test_select_star(self):
self.login(username="admin")
@ -950,7 +955,11 @@ class CoreTests(SupersetTestCase):
self.assertDictEqual(deserialized_payload, payload)
expand_data.assert_called_once()
@mock.patch.dict("superset._feature_flags", {"FOO": lambda x: 1}, clear=True)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
{"FOO": lambda x: 1},
clear=True,
)
def test_feature_flag_serialization(self):
"""
Functions in feature flags don't break bootstrap data serialization.

View File

@ -17,6 +17,7 @@
"""Unit tests for Superset"""
import json
import unittest
from random import random
from flask import escape
from sqlalchemy import func
@ -399,16 +400,19 @@ class DashboardTests(SupersetTestCase):
self.grant_public_access_to_table(table)
hidden_dash_slug = f"hidden_dash_{random()}"
published_dash_slug = f"published_dash_{random()}"
# Create a published and hidden dashboard and add them to the database
published_dash = models.Dashboard()
published_dash.dashboard_title = "Published Dashboard"
published_dash.slug = "published_dash"
published_dash.slug = published_dash_slug
published_dash.slices = [slice]
published_dash.published = True
hidden_dash = models.Dashboard()
hidden_dash.dashboard_title = "Hidden Dashboard"
hidden_dash.slug = "hidden_dash"
hidden_dash.slug = hidden_dash_slug
hidden_dash.slices = [slice]
hidden_dash.published = False
@ -417,22 +421,24 @@ class DashboardTests(SupersetTestCase):
db.session.commit()
resp = self.get_resp("/dashboard/list/")
self.assertNotIn("/superset/dashboard/hidden_dash/", resp)
self.assertIn("/superset/dashboard/published_dash/", resp)
self.assertNotIn(f"/superset/dashboard/{hidden_dash_slug}/", resp)
self.assertIn(f"/superset/dashboard/{published_dash_slug}/", resp)
def test_users_can_view_own_dashboard(self):
user = security_manager.find_user("gamma")
my_dash_slug = f"my_dash_{random()}"
not_my_dash_slug = f"not_my_dash_{random()}"
# Create one dashboard I own and another that I don't
dash = models.Dashboard()
dash.dashboard_title = "My Dashboard"
dash.slug = "my_dash"
dash.slug = my_dash_slug
dash.owners = [user]
dash.slices = []
hidden_dash = models.Dashboard()
hidden_dash.dashboard_title = "Not My Dashboard"
hidden_dash.slug = "not_my_dash"
hidden_dash.slug = not_my_dash_slug
hidden_dash.slices = []
hidden_dash.owners = []
@ -443,29 +449,27 @@ class DashboardTests(SupersetTestCase):
self.login(user.username)
resp = self.get_resp("/dashboard/list/")
self.assertIn("/superset/dashboard/my_dash/", resp)
self.assertNotIn("/superset/dashboard/not_my_dash/", resp)
self.assertIn(f"/superset/dashboard/{my_dash_slug}/", resp)
self.assertNotIn(f"/superset/dashboard/{not_my_dash_slug}/", resp)
def test_users_can_view_favorited_dashboards(self):
user = security_manager.find_user("gamma")
fav_dash_slug = f"my_favorite_dash_{random()}"
regular_dash_slug = f"regular_dash_{random()}"
favorite_dash = models.Dashboard()
favorite_dash.dashboard_title = "My Favorite Dashboard"
favorite_dash.slug = "my_favorite_dash"
favorite_dash.slug = fav_dash_slug
regular_dash = models.Dashboard()
regular_dash.dashboard_title = "A Plain Ol Dashboard"
regular_dash.slug = "regular_dash"
regular_dash.slug = regular_dash_slug
db.session.merge(favorite_dash)
db.session.merge(regular_dash)
db.session.commit()
dash = (
db.session.query(models.Dashboard)
.filter_by(slug="my_favorite_dash")
.first()
)
dash = db.session.query(models.Dashboard).filter_by(slug=fav_dash_slug).first()
favorites = models.FavStar()
favorites.obj_id = dash.id
@ -478,12 +482,12 @@ class DashboardTests(SupersetTestCase):
self.login(user.username)
resp = self.get_resp("/dashboard/list/")
self.assertIn("/superset/dashboard/my_favorite_dash/", resp)
self.assertIn(f"/superset/dashboard/{fav_dash_slug}/", resp)
def test_user_can_not_view_unpublished_dash(self):
admin_user = security_manager.find_user("admin")
gamma_user = security_manager.find_user("gamma")
slug = "admin_owned_unpublished_dash"
slug = f"admin_owned_unpublished_dash_{random()}"
# Create a dashboard owned by admin and unpublished
dash = models.Dashboard()

View File

@ -60,7 +60,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.verify_presto_column(presto_column, expected_results)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_get_simple_row_column(self):
presto_column = ("column_name", "row(nested_obj double)", "")
@ -68,7 +70,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.verify_presto_column(presto_column, expected_results)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
presto_column = ("column name", "row(nested_obj double)", "")
@ -76,7 +80,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.verify_presto_column(presto_column, expected_results)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
presto_column = ("column_name", 'row("Field Name(Tricky, Name)" double)', "")
@ -87,7 +93,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.verify_presto_column(presto_column, expected_results)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_get_simple_array_column(self):
presto_column = ("column_name", "array(double)", "")
@ -95,7 +103,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.verify_presto_column(presto_column, expected_results)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_get_row_within_array_within_row_column(self):
presto_column = (
@ -112,7 +122,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.verify_presto_column(presto_column, expected_results)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_get_array_within_row_within_array_column(self):
presto_column = (
@ -147,7 +159,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.assertEqual(actual_result.name, expected_result["label"])
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_expand_data_with_simple_structural_columns(self):
cols = [
@ -182,7 +196,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_expand_data_with_complex_row_columns(self):
cols = [
@ -229,7 +245,9 @@ class PrestoTests(DbEngineSpecTestCase):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
@mock.patch.dict(
"superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True
"superset.extensions.feature_flag_manager._feature_flags",
{"PRESTO_EXPAND_DATA": True},
clear=True,
)
def test_presto_expand_data_with_complex_array_columns(self):
cols = [

View File

@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
import yaml
from tests.test_app import app
from superset import db
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
@ -41,15 +43,16 @@ class DictImportExportTests(SupersetTestCase):
@classmethod
def delete_imports(cls):
# Imported data clean up
session = db.session
for table in session.query(SqlaTable):
if DBREF in table.params_dict:
session.delete(table)
for datasource in session.query(DruidDatasource):
if DBREF in datasource.params_dict:
session.delete(datasource)
session.commit()
with app.app_context():
# Imported data clean up
session = db.session
for table in session.query(SqlaTable):
if DBREF in table.params_dict:
session.delete(table)
for datasource in session.query(DruidDatasource):
if DBREF in datasource.params_dict:
session.delete(datasource)
session.commit()
@classmethod
def setUpClass(cls):

View File

@ -47,7 +47,7 @@ def emplace(metrics_dict, metric_name, is_postagg=False):
# Unit tests that can be run without initializing base tests
class DruidFuncTestCase(unittest.TestCase):
class DruidFuncTestCase(SupersetTestCase):
@unittest.skipUnless(
SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
)

View File

@ -26,13 +26,14 @@ from unittest import mock
from superset import app
from superset.utils import core as utils
from tests.base_tests import SupersetTestCase
from .utils import read_fixture
send_email_test = mock.Mock()
class EmailSmtpTest(unittest.TestCase):
class EmailSmtpTest(SupersetTestCase):
def setUp(self):
app.config["smtp_ssl"] = False

View File

@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch
from superset import is_feature_enabled
from tests.base_tests import SupersetTestCase
class FeatureFlagTests(SupersetTestCase):
@patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
{"FOO": True},
clear=True,
)
def test_existing_feature_flags(self):
self.assertTrue(is_feature_enabled("FOO"))
@patch.dict(
"superset.extensions.feature_flag_manager._feature_flags", {}, clear=True
)
def test_nonexistent_feature_flags(self):
self.assertFalse(is_feature_enabled("FOO"))
def test_feature_flags(self):
self.assertEqual(is_feature_enabled("foo"), "bar")
self.assertEqual(is_feature_enabled("super"), "set")

View File

@ -14,13 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest
from flask import Flask, g
from flask import g
from sqlalchemy.orm.session import make_transient
from tests.test_app import app
from superset import db, security_manager
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
@ -35,21 +37,22 @@ class ImportExportTests(SupersetTestCase):
@classmethod
def delete_imports(cls):
# Imported data clean up
session = db.session
for slc in session.query(models.Slice):
if "remote_id" in slc.params_dict:
session.delete(slc)
for dash in session.query(models.Dashboard):
if "remote_id" in dash.params_dict:
session.delete(dash)
for table in session.query(SqlaTable):
if "remote_id" in table.params_dict:
session.delete(table)
for datasource in session.query(DruidDatasource):
if "remote_id" in datasource.params_dict:
session.delete(datasource)
session.commit()
with app.app_context():
# Imported data clean up
session = db.session
for slc in session.query(models.Slice):
if "remote_id" in slc.params_dict:
session.delete(slc)
for dash in session.query(models.Dashboard):
if "remote_id" in dash.params_dict:
session.delete(dash)
for table in session.query(SqlaTable):
if "remote_id" in table.params_dict:
session.delete(table)
for datasource in session.query(DruidDatasource):
if "remote_id" in datasource.params_dict:
session.delete(datasource)
session.commit()
@classmethod
def setUpClass(cls):
@ -460,68 +463,64 @@ class ImportExportTests(SupersetTestCase):
)
def test_import_new_dashboard_slice_reset_ownership(self):
app = Flask("test_import_dashboard_slice_set_user")
with app.app_context():
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
g.user = gamma_user
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
g.user = gamma_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
# set another user as an owner of importing dashboard
dash_with_1_slice.created_by = admin_user
dash_with_1_slice.changed_by = admin_user
dash_with_1_slice.owners = [admin_user]
dash_with_1_slice = self._create_dashboard_for_import(id_=10200)
# set another user as an owner of importing dashboard
dash_with_1_slice.created_by = admin_user
dash_with_1_slice.changed_by = admin_user
dash_with_1_slice.owners = [admin_user]
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
def test_import_override_dashboard_slice_reset_ownership(self):
app = Flask("test_import_dashboard_slice_set_user")
with app.app_context():
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
g.user = gamma_user
admin_user = security_manager.find_user(username="admin")
self.assertTrue(admin_user)
gamma_user = security_manager.find_user(username="gamma")
self.assertTrue(gamma_user)
g.user = gamma_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
# re-import with another user shouldn't change the permissions
g.user = admin_user
# re-import with another user shouldn't change the permissions
g.user = admin_user
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
dash_with_1_slice = self._create_dashboard_for_import(id_=10300)
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice)
imported_dash = self.get_dash(imported_dash_id)
self.assertEqual(imported_dash.created_by, gamma_user)
self.assertEqual(imported_dash.changed_by, gamma_user)
self.assertEqual(imported_dash.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
imported_slc = imported_dash.slices[0]
self.assertEqual(imported_slc.created_by, gamma_user)
self.assertEqual(imported_slc.changed_by, gamma_user)
self.assertEqual(imported_slc.owners, [gamma_user])
def _create_dashboard_for_import(self, id_=10100):
slc = self.create_slice("health_slc" + str(id_), id=id_ + 1)

View File

@ -14,27 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset import examples
from superset.cli import load_test_users_run
from .base_tests import SupersetTestCase
class SupersetDataFrameTestCase(SupersetTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.examples = None
def setUp(self) -> None:
# Late importing here as we need an app context to be pushed...
from superset import examples
self.examples = examples
def test_load_css_templates(self):
examples.load_css_templates()
self.examples.load_css_templates()
def test_load_energy(self):
examples.load_energy()
self.examples.load_energy()
def test_load_world_bank_health_n_pop(self):
examples.load_world_bank_health_n_pop()
self.examples.load_world_bank_health_n_pop()
def test_load_birth_names(self):
examples.load_birth_names()
self.examples.load_birth_names()
def test_load_test_users_run(self):
from superset.cli import load_test_users_run
load_test_users_run()
def test_load_unicode_test_data(self):
examples.load_unicode_test_data()
self.examples.load_unicode_test_data()

View File

@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
# isort:skip_file
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, PropertyMock
from flask_babel import gettext as __
from selenium.common.exceptions import WebDriverException
from superset import app, db
from tests.test_app import app
from superset import db
from superset.models.core import Dashboard, Slice
from superset.models.schedules import (
DashboardEmailSchedule,
@ -35,11 +36,12 @@ from superset.tasks.schedules import (
deliver_slice,
next_schedules,
)
from tests.base_tests import SupersetTestCase
from .utils import read_fixture
class SchedulesTestCase(unittest.TestCase):
class SchedulesTestCase(SupersetTestCase):
RECIPIENTS = "recipient1@superset.com, recipient2@superset.com"
BCC = "bcc@superset.com"
@ -47,41 +49,45 @@ class SchedulesTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.common_data = dict(
active=True,
crontab="* * * * *",
recipients=cls.RECIPIENTS,
deliver_as_group=True,
delivery_type=EmailDeliveryType.inline,
)
with app.app_context():
cls.common_data = dict(
active=True,
crontab="* * * * *",
recipients=cls.RECIPIENTS,
deliver_as_group=True,
delivery_type=EmailDeliveryType.inline,
)
# Pick up a random slice and dashboard
slce = db.session.query(Slice).all()[0]
dashboard = db.session.query(Dashboard).all()[0]
# Pick up a random slice and dashboard
slce = db.session.query(Slice).all()[0]
dashboard = db.session.query(Dashboard).all()[0]
dashboard_schedule = DashboardEmailSchedule(**cls.common_data)
dashboard_schedule.dashboard_id = dashboard.id
dashboard_schedule.user_id = 1
db.session.add(dashboard_schedule)
dashboard_schedule = DashboardEmailSchedule(**cls.common_data)
dashboard_schedule.dashboard_id = dashboard.id
dashboard_schedule.user_id = 1
db.session.add(dashboard_schedule)
slice_schedule = SliceEmailSchedule(**cls.common_data)
slice_schedule.slice_id = slce.id
slice_schedule.user_id = 1
slice_schedule.email_format = SliceEmailReportFormat.data
slice_schedule = SliceEmailSchedule(**cls.common_data)
slice_schedule.slice_id = slce.id
slice_schedule.user_id = 1
slice_schedule.email_format = SliceEmailReportFormat.data
db.session.add(slice_schedule)
db.session.commit()
db.session.add(slice_schedule)
db.session.commit()
cls.slice_schedule = slice_schedule.id
cls.dashboard_schedule = dashboard_schedule.id
cls.slice_schedule = slice_schedule.id
cls.dashboard_schedule = dashboard_schedule.id
@classmethod
def tearDownClass(cls):
db.session.query(SliceEmailSchedule).filter_by(id=cls.slice_schedule).delete()
db.session.query(DashboardEmailSchedule).filter_by(
id=cls.dashboard_schedule
).delete()
db.session.commit()
with app.app_context():
db.session.query(SliceEmailSchedule).filter_by(
id=cls.slice_schedule
).delete()
db.session.query(DashboardEmailSchedule).filter_by(
id=cls.dashboard_schedule
).delete()
db.session.commit()
def test_crontab_scheduler(self):
crontab = "* * * * *"

View File

@ -310,6 +310,7 @@ class RolePermissionTests(SupersetTestCase):
["Superset", "welcome"],
["SecurityApi", "login"],
["SecurityApi", "refresh"],
["SupersetIndexView", "index"],
]
unsecured_views = []
for view_class in appbuilder.baseviews:

View File

@ -60,7 +60,11 @@ class SqlValidatorEndpointTests(SupersetTestCase):
self.assertIn("no SQL validator is configured", resp["error"])
@patch("superset.views.core.get_validator_by_name")
@patch.dict("superset._feature_flags", PRESTO_TEST_FEATURE_FLAGS, clear=True)
@patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
PRESTO_TEST_FEATURE_FLAGS,
clear=True,
)
def test_validate_sql_endpoint_mocked(self, get_validator_by_name):
"""Assert that, with a mocked validator, annotations make it back out
from the validate_sql_json endpoint as a list of json dictionaries"""
@ -87,7 +91,11 @@ class SqlValidatorEndpointTests(SupersetTestCase):
self.assertIn("expected,", resp[0]["message"])
@patch("superset.views.core.get_validator_by_name")
@patch.dict("superset._feature_flags", PRESTO_TEST_FEATURE_FLAGS, clear=True)
@patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
PRESTO_TEST_FEATURE_FLAGS,
clear=True,
)
def test_validate_sql_endpoint_failure(self, get_validator_by_name):
"""Assert that validate_sql_json errors out when the selected validator
raises an unexpected exception"""

View File

@ -17,6 +17,7 @@
"""Unit tests for Sql Lab"""
import json
from datetime import datetime, timedelta
from random import random
import prison
@ -294,19 +295,19 @@ class SqlLabTests(SupersetTestCase):
examples_dbid = get_example_database().id
payload = {
"chartType": "dist_bar",
"datasourceName": "test_viz_flow_table",
"datasourceName": f"test_viz_flow_table_{random()}",
"schema": "superset",
"columns": [
{
"is_date": False,
"type": "STRING",
"name": "viz_type",
"name": f"viz_type_{random()}",
"is_dim": True,
},
{
"is_date": False,
"type": "OBJECT",
"name": "ccount",
"name": f"ccount_{random()}",
"is_dim": True,
"agg": "sum",
},
@ -421,3 +422,4 @@ class SqlLabTests(SupersetTestCase):
{"examples", "fake_db_100"},
{r.get("database_name") for r in self.get_json_resp(url)["result"]},
)
self.delete_fake_db()

24
tests/test_app.py Normal file
View File

@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Here is where we create the app which ends up being shared across all tests. A future
optimization will be to create a separate app instance for each test class.
"""
from superset.app import create_app
app = create_app()

View File

@ -28,6 +28,7 @@ from sqlalchemy.exc import ArgumentError
from superset import app, db, security_manager
from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.utils.cache_manager import CacheManager
from superset.utils.core import (
base_json_conv,
convert_legacy_filters_into_adhoc,
@ -45,7 +46,6 @@ from superset.utils.core import (
parse_human_timedelta,
parse_js_uri_path_item,
parse_past_timedelta,
setup_cache,
split,
TimeRangeEndpoint,
validate_json,
@ -53,6 +53,7 @@ from superset.utils.core import (
zlib_decompress,
)
from superset.views.utils import get_time_range_endpoints
from tests.base_tests import SupersetTestCase
def mock_parse_human_datetime(s):
@ -93,7 +94,7 @@ def mock_to_adhoc(filt, expressionType="SIMPLE", clause="where"):
return result
class UtilsTestCase(unittest.TestCase):
class UtilsTestCase(SupersetTestCase):
def test_json_int_dttm_ser(self):
dttm = datetime(2020, 1, 1)
ts = 1577836800000.0
@ -809,12 +810,12 @@ class UtilsTestCase(unittest.TestCase):
def test_setup_cache_no_config(self):
app = Flask(__name__)
cache_config = None
self.assertIsNone(setup_cache(app, cache_config))
self.assertIsNone(CacheManager._setup_cache(app, cache_config))
def test_setup_cache_null_config(self):
app = Flask(__name__)
cache_config = {"CACHE_TYPE": "null"}
self.assertIsNone(setup_cache(app, cache_config))
self.assertIsNone(CacheManager._setup_cache(app, cache_config))
def test_setup_cache_standard_config(self):
app = Flask(__name__)
@ -824,7 +825,7 @@ class UtilsTestCase(unittest.TestCase):
"CACHE_KEY_PREFIX": "superset_results",
"CACHE_REDIS_URL": "redis://localhost:6379/0",
}
assert isinstance(setup_cache(app, cache_config), Cache) is True
assert isinstance(CacheManager._setup_cache(app, cache_config), Cache) is True
def test_setup_cache_custom_function(self):
app = Flask(__name__)
@ -833,7 +834,9 @@ class UtilsTestCase(unittest.TestCase):
def init_cache(app):
return CustomCache(app, {})
assert isinstance(setup_cache(app, init_cache), CustomCache) is True
assert (
isinstance(CacheManager._setup_cache(app, init_cache), CustomCache) is True
)
def test_get_stacktrace(self):
with app.app_context():
@ -879,6 +882,8 @@ class UtilsTestCase(unittest.TestCase):
get_or_create_db("test_db", "sqlite:///changed.db")
database = db.session.query(Database).filter_by(database_name="test_db").one()
self.assertEqual(database.sqlalchemy_uri, "sqlite:///changed.db")
db.session.delete(database)
db.session.commit()
def test_get_or_create_db_invalid_uri(self):
with self.assertRaises(ArgumentError):

View File

@ -19,7 +19,7 @@ commands =
{toxinidir}/superset/bin/superset db upgrade
{toxinidir}/superset/bin/superset init
nosetests tests/load_examples_test.py
nosetests -e load_examples_test {posargs}
nosetests -e load_examples_test tests {posargs}
deps =
-rrequirements.txt
-rrequirements-dev.txt