diff --git a/superset/views/core.py b/superset/views/core.py index f94a906d17..1c4640c71c 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -27,17 +27,7 @@ import msgpack import pandas as pd import pyarrow as pa import simplejson as json -from flask import ( - abort, - flash, - g, - Markup, - redirect, - render_template, - request, - Response, - url_for, -) +from flask import abort, flash, g, Markup, redirect, render_template, request, Response from flask_appbuilder import expose from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api @@ -46,7 +36,6 @@ from flask_babel import gettext as __, lazy_gettext as _ from sqlalchemy import and_, Integer, or_, select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.session import Session -from werkzeug.routing import BaseConverter from werkzeug.urls import Href import superset.models.core as models @@ -88,11 +77,10 @@ from superset.sql_validators import get_validator_by_name from superset.utils import core as utils, dashboard_import_export from superset.utils.dates import now_as_float from superset.utils.decorators import etag_cache, stats_timing -from superset.views.chart import views as chart_views +from superset.views.database.filters import DatabaseFilter from .base import ( api, - BaseFilter, BaseSupersetView, check_ownership, common_bootstrap_payload, @@ -107,8 +95,6 @@ from .base import ( json_success, SupersetModelView, ) -from .dashboard import views as dash_views -from .database import views as in_views from .utils import ( apply_display_max_row_limit, bootstrap_user_data, @@ -1068,21 +1054,30 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/tables////") - @expose("/tables/////") - def tables(self, db_id, schema, substr, force_refresh="false"): + @expose("/tables////") + @expose("/tables/////") + def tables( + self, db_id: int, schema: str, substr: str, force_refresh: str = "false" + ): """Endpoint to fetch the list of tables for given database""" - db_id = int(db_id) - force_refresh = force_refresh.lower() == "true" - schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) - substr = utils.parse_js_uri_path_item(substr, eval_undefined=True) - database = db.session.query(models.Database).filter_by(id=db_id).one() + # Guarantees database filtering by security access + query = db.session.query(models.Database) + query = DatabaseFilter("id", SQLAInterface(models.Database, db.session)).apply( + query, None + ) + database = query.filter_by(id=db_id).one_or_none() + if not database: + return json_error_response("Not found", 404) - if schema: + force_refresh_parsed = force_refresh.lower() == "true" + schema_parsed = utils.parse_js_uri_path_item(schema, eval_undefined=True) + substr_parsed = utils.parse_js_uri_path_item(substr, eval_undefined=True) + + if schema_parsed: tables = ( database.get_all_table_names_in_schema( - schema=schema, - force=force_refresh, + schema=schema_parsed, + force=force_refresh_parsed, cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) @@ -1090,8 +1085,8 @@ class Superset(BaseSupersetView): ) views = ( database.get_all_view_names_in_schema( - schema=schema, - force=force_refresh, + schema=schema_parsed, + force=force_refresh_parsed, cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) @@ -1105,20 +1100,22 @@ class Superset(BaseSupersetView): cache=True, force=False, cache_timeout=24 * 60 * 60 ) tables = security_manager.get_datasources_accessible_by_user( - database, tables, schema + database, tables, schema_parsed ) views = security_manager.get_datasources_accessible_by_user( - database, views, schema + database, views, schema_parsed ) def get_datasource_label(ds_name: utils.DatasourceName) -> str: - return ds_name.table if schema else f"{ds_name.schema}.{ds_name.table}" + return ( + ds_name.table if schema_parsed else f"{ds_name.schema}.{ds_name.table}" + ) - if substr: - tables = [tn for tn in tables if substr in get_datasource_label(tn)] - views = [vn for vn in views if substr in get_datasource_label(vn)] + if substr_parsed: + tables = [tn for tn in tables if substr_parsed in get_datasource_label(tn)] + views = [vn for vn in views if substr_parsed in get_datasource_label(vn)] - if not schema and database.default_schemas: + if not schema_parsed and database.default_schemas: user_schema = g.user.email.split("@")[0] valid_schemas = set(database.default_schemas + [user_schema]) @@ -1129,7 +1126,7 @@ class Superset(BaseSupersetView): total_items = len(tables) + len(views) max_tables = len(tables) max_views = len(views) - if total_items and substr: + if total_items and substr_parsed: max_tables = max_items * len(tables) // total_items max_views = max_items * len(views) // total_items diff --git a/superset/views/database/api.py b/superset/views/database/api.py index 2eb5ad88cf..d4ee3c995e 100644 --- a/superset/views/database/api.py +++ b/superset/views/database/api.py @@ -143,6 +143,8 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi): max_page_size = -1 validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator} + openapi_spec_tag = "Database" + @expose( "//table///", methods=["GET"] ) diff --git a/tests/base_tests.py b/tests/base_tests.py index 94f0cebc73..5728dd8868 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -40,6 +40,13 @@ FAKE_DB_NAME = "fake_db_100" class SupersetTestCase(TestCase): + + default_schema_backend_map = { + "sqlite": "main", + "mysql": "superset", + "postgresql": "public", + } + def __init__(self, *args, **kwargs): super(SupersetTestCase, self).__init__(*args, **kwargs) self.maxDiff = None diff --git a/tests/core_tests.py b/tests/core_tests.py index 19a5ccd381..cf93d7acab 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -165,6 +165,43 @@ class CoreTests(SupersetTestCase): # the new cache_key should be different due to updated datasource self.assertNotEqual(cache_key_original, cache_key_new) + def test_get_superset_tables_not_allowed(self): + example_db = utils.get_example_database() + schema_name = self.default_schema_backend_map[example_db.backend] + self.login(username="gamma") + uri = f"superset/tables/{example_db.id}/{schema_name}/undefined/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_superset_tables_substr(self): + example_db = utils.get_example_database() + self.login(username="admin") + schema_name = self.default_schema_backend_map[example_db.backend] + uri = f"superset/tables/{example_db.id}/{schema_name}/ab_role/" + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + expeted_response = { + "options": [ + { + "label": "ab_role", + "schema": schema_name, + "title": "ab_role", + "type": "table", + "value": "ab_role", + } + ], + "tableLength": 1, + } + self.assertEqual(response, expeted_response) + + def test_get_superset_tables_not_found(self): + self.login(username="admin") + uri = f"superset/tables/invalid/public/undefined/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + def test_api_v1_query_endpoint(self): self.login(username="admin") qc_dict = self._get_query_context_dict()