diff --git a/superset/app.py b/superset/app.py index 2d89ffa988..48883a9264 100644 --- a/superset/app.py +++ b/superset/app.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) def create_app() -> Flask: - app = Flask(__name__) + app = SupersetApp(__name__) try: # Allow user to override our config completely @@ -42,3 +42,7 @@ def create_app() -> Flask: except Exception as ex: logger.exception("Failed to create app") raise ex + + +class SupersetApp(Flask): + pass diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 0171b170e2..2654590d25 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import logging import os -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, TYPE_CHECKING import wtforms_json from flask import Flask, redirect @@ -48,15 +49,18 @@ from superset.typing import FlaskResponse from superset.utils.core import pessimistic_connection_handling from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value +if TYPE_CHECKING: + from superset.app import SupersetApp + logger = logging.getLogger(__name__) # pylint: disable=R0904 class SupersetAppInitializer: - def __init__(self, app: Flask) -> None: + def __init__(self, app: SupersetApp) -> None: super().__init__() - self.flask_app = app + self.superset_app = app self.config = app.config self.manifest: Dict[Any, Any] = {} @@ -77,7 +81,7 @@ class SupersetAppInitializer: def configure_celery(self) -> None: celery_app.config_from_object(self.config["CELERY_CONFIG"]) celery_app.set_default() - flask_app = self.flask_app + superset_app = self.superset_app # Here, we want to ensure that every call into Celery task has an app context # setup properly @@ -89,7 +93,7 @@ class SupersetAppInitializer: # Grab each call into the task and set up an app context def __call__(self, *args: Any, **kwargs: Any) -> Any: - with flask_app.app_context(): # type: ignore + with superset_app.app_context(): # type: ignore return task_base.__call__(self, *args, **kwargs) celery_app.Task = AppContextTask @@ -538,7 +542,7 @@ class SupersetAppInitializer: # after initialization flask_app_mutator = self.config["FLASK_APP_MUTATOR"] if flask_app_mutator: - flask_app_mutator(self.flask_app) + flask_app_mutator(self.superset_app) self.init_views() @@ -563,17 +567,17 @@ class SupersetAppInitializer: self.configure_middlewares() self.configure_cache() - with self.flask_app.app_context(): # type: ignore + with self.superset_app.app_context(): # type: ignore self.init_app_in_ctx() self.post_init() def configure_auth_provider(self) -> None: - machine_auth_provider_factory.init_app(self.flask_app) + machine_auth_provider_factory.init_app(self.superset_app) def setup_event_logger(self) -> None: _event_logger["event_logger"] = get_event_logger_from_cfg_value( - self.flask_app.config.get("EVENT_LOGGER", DBEventLogger()) + self.superset_app.config.get("EVENT_LOGGER", DBEventLogger()) ) def configure_data_sources(self) -> None: @@ -583,11 +587,11 @@ class SupersetAppInitializer: ConnectorRegistry.register_sources(module_datasource_map) def configure_cache(self) -> None: - cache_manager.init_app(self.flask_app) - results_backend_manager.init_app(self.flask_app) + cache_manager.init_app(self.superset_app) + results_backend_manager.init_app(self.superset_app) def configure_feature_flags(self) -> None: - feature_flag_manager.init_app(self.flask_app) + feature_flag_manager.init_app(self.superset_app) def configure_fab(self) -> None: if self.config["SILENCE_FAB"]: @@ -604,7 +608,7 @@ class SupersetAppInitializer: appbuilder.indexview = SupersetIndexView appbuilder.base_template = "superset/base.html" appbuilder.security_manager_class = custom_sm - appbuilder.init_app(self.flask_app, db.session) + appbuilder.init_app(self.superset_app, db.session) def configure_url_map_converters(self) -> None: # @@ -616,20 +620,20 @@ class SupersetAppInitializer: RegexConverter, ) - self.flask_app.url_map.converters["regex"] = RegexConverter - self.flask_app.url_map.converters["object_type"] = ObjectTypeConverter + self.superset_app.url_map.converters["regex"] = RegexConverter + self.superset_app.url_map.converters["object_type"] = ObjectTypeConverter def configure_middlewares(self) -> None: if self.config["ENABLE_CORS"]: from flask_cors import CORS - CORS(self.flask_app, **self.config["CORS_OPTIONS"]) + CORS(self.superset_app, **self.config["CORS_OPTIONS"]) if self.config["ENABLE_PROXY_FIX"]: from werkzeug.middleware.proxy_fix import ProxyFix - self.flask_app.wsgi_app = ProxyFix( # type: ignore - self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"] + self.superset_app.wsgi_app = ProxyFix( # type: ignore + self.superset_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"] ) if self.config["ENABLE_CHUNK_ENCODING"]: @@ -647,8 +651,8 @@ class SupersetAppInitializer: environ["wsgi.input_terminated"] = True return self.app(environ, start_response) - self.flask_app.wsgi_app = ChunkedEncodingFix( # type: ignore - self.flask_app.wsgi_app # type: ignore + self.superset_app.wsgi_app = ChunkedEncodingFix( # type: ignore + self.superset_app.wsgi_app # type: ignore ) if self.config["UPLOAD_FOLDER"]: @@ -658,53 +662,53 @@ class SupersetAppInitializer: pass for middleware in self.config["ADDITIONAL_MIDDLEWARE"]: - self.flask_app.wsgi_app = middleware( # type: ignore - self.flask_app.wsgi_app + self.superset_app.wsgi_app = middleware( # type: ignore + self.superset_app.wsgi_app ) # Flask-Compress - Compress(self.flask_app) + Compress(self.superset_app) if self.config["TALISMAN_ENABLED"]: - talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"]) + talisman.init_app(self.superset_app, **self.config["TALISMAN_CONFIG"]) def configure_logging(self) -> None: self.config["LOGGING_CONFIGURATOR"].configure_logging( - self.config, self.flask_app.debug + self.config, self.superset_app.debug ) def configure_db_encrypt(self) -> None: - encrypted_field_factory.init_app(self.flask_app) + encrypted_field_factory.init_app(self.superset_app) def setup_db(self) -> None: - db.init_app(self.flask_app) + db.init_app(self.superset_app) - with self.flask_app.app_context(): # type: ignore + with self.superset_app.app_context(): # type: ignore pessimistic_connection_handling(db.engine) - migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations") + migrate.init_app(self.superset_app, db=db, directory=APP_DIR + "/migrations") def configure_wtf(self) -> None: if self.config["WTF_CSRF_ENABLED"]: - csrf.init_app(self.flask_app) + csrf.init_app(self.superset_app) csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"] for ex in csrf_exempt_list: csrf.exempt(ex) def configure_async_queries(self) -> None: if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"): - async_query_manager.init_app(self.flask_app) + async_query_manager.init_app(self.superset_app) def register_blueprints(self) -> None: for bp in self.config["BLUEPRINTS"]: try: logger.info("Registering blueprint: %s", bp.name) - self.flask_app.register_blueprint(bp) + self.superset_app.register_blueprint(bp) except Exception: # pylint: disable=broad-except logger.exception("blueprint registration failed") def setup_bundle_manifest(self) -> None: - manifest_processor.init_app(self.flask_app) + manifest_processor.init_app(self.superset_app) class SupersetIndexView(IndexView):