[database] Fix tables API endpoint (#9144)

This commit is contained in:
Daniel Vaz Gaspar 2020-02-20 10:15:22 +00:00 committed by GitHub
parent c1750af54a
commit e55fe43ca6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 37 deletions

View File

@ -27,17 +27,7 @@ import msgpack
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
import simplejson as json import simplejson as json
from flask import ( from flask import abort, flash, g, Markup, redirect, render_template, request, Response
abort,
flash,
g,
Markup,
redirect,
render_template,
request,
Response,
url_for,
)
from flask_appbuilder import expose from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access, has_access_api from flask_appbuilder.security.decorators import has_access, has_access_api
@ -46,7 +36,6 @@ from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import and_, Integer, or_, select from sqlalchemy import and_, Integer, or_, select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from werkzeug.routing import BaseConverter
from werkzeug.urls import Href from werkzeug.urls import Href
import superset.models.core as models import superset.models.core as models
@ -88,11 +77,10 @@ from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils, dashboard_import_export from superset.utils import core as utils, dashboard_import_export
from superset.utils.dates import now_as_float from superset.utils.dates import now_as_float
from superset.utils.decorators import etag_cache, stats_timing from superset.utils.decorators import etag_cache, stats_timing
from superset.views.chart import views as chart_views from superset.views.database.filters import DatabaseFilter
from .base import ( from .base import (
api, api,
BaseFilter,
BaseSupersetView, BaseSupersetView,
check_ownership, check_ownership,
common_bootstrap_payload, common_bootstrap_payload,
@ -107,8 +95,6 @@ from .base import (
json_success, json_success,
SupersetModelView, SupersetModelView,
) )
from .dashboard import views as dash_views
from .database import views as in_views
from .utils import ( from .utils import (
apply_display_max_row_limit, apply_display_max_row_limit,
bootstrap_user_data, bootstrap_user_data,
@ -1068,21 +1054,30 @@ class Superset(BaseSupersetView):
@api @api
@has_access_api @has_access_api
@expose("/tables/<db_id>/<schema>/<substr>/") @expose("/tables/<int:db_id>/<schema>/<substr>/")
@expose("/tables/<db_id>/<schema>/<substr>/<force_refresh>/") @expose("/tables/<int:db_id>/<schema>/<substr>/<force_refresh>/")
def tables(self, db_id, schema, substr, force_refresh="false"): def tables(
self, db_id: int, schema: str, substr: str, force_refresh: str = "false"
):
"""Endpoint to fetch the list of tables for given database""" """Endpoint to fetch the list of tables for given database"""
db_id = int(db_id) # Guarantees database filtering by security access
force_refresh = force_refresh.lower() == "true" query = db.session.query(models.Database)
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) query = DatabaseFilter("id", SQLAInterface(models.Database, db.session)).apply(
substr = utils.parse_js_uri_path_item(substr, eval_undefined=True) query, None
database = db.session.query(models.Database).filter_by(id=db_id).one() )
database = query.filter_by(id=db_id).one_or_none()
if not database:
return json_error_response("Not found", 404)
if schema: force_refresh_parsed = force_refresh.lower() == "true"
schema_parsed = utils.parse_js_uri_path_item(schema, eval_undefined=True)
substr_parsed = utils.parse_js_uri_path_item(substr, eval_undefined=True)
if schema_parsed:
tables = ( tables = (
database.get_all_table_names_in_schema( database.get_all_table_names_in_schema(
schema=schema, schema=schema_parsed,
force=force_refresh, force=force_refresh_parsed,
cache=database.table_cache_enabled, cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout, cache_timeout=database.table_cache_timeout,
) )
@ -1090,8 +1085,8 @@ class Superset(BaseSupersetView):
) )
views = ( views = (
database.get_all_view_names_in_schema( database.get_all_view_names_in_schema(
schema=schema, schema=schema_parsed,
force=force_refresh, force=force_refresh_parsed,
cache=database.table_cache_enabled, cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout, cache_timeout=database.table_cache_timeout,
) )
@ -1105,20 +1100,22 @@ class Superset(BaseSupersetView):
cache=True, force=False, cache_timeout=24 * 60 * 60 cache=True, force=False, cache_timeout=24 * 60 * 60
) )
tables = security_manager.get_datasources_accessible_by_user( tables = security_manager.get_datasources_accessible_by_user(
database, tables, schema database, tables, schema_parsed
) )
views = security_manager.get_datasources_accessible_by_user( views = security_manager.get_datasources_accessible_by_user(
database, views, schema database, views, schema_parsed
) )
def get_datasource_label(ds_name: utils.DatasourceName) -> str: def get_datasource_label(ds_name: utils.DatasourceName) -> str:
return ds_name.table if schema else f"{ds_name.schema}.{ds_name.table}" return (
ds_name.table if schema_parsed else f"{ds_name.schema}.{ds_name.table}"
)
if substr: if substr_parsed:
tables = [tn for tn in tables if substr in get_datasource_label(tn)] tables = [tn for tn in tables if substr_parsed in get_datasource_label(tn)]
views = [vn for vn in views if substr in get_datasource_label(vn)] views = [vn for vn in views if substr_parsed in get_datasource_label(vn)]
if not schema and database.default_schemas: if not schema_parsed and database.default_schemas:
user_schema = g.user.email.split("@")[0] user_schema = g.user.email.split("@")[0]
valid_schemas = set(database.default_schemas + [user_schema]) valid_schemas = set(database.default_schemas + [user_schema])
@ -1129,7 +1126,7 @@ class Superset(BaseSupersetView):
total_items = len(tables) + len(views) total_items = len(tables) + len(views)
max_tables = len(tables) max_tables = len(tables)
max_views = len(views) max_views = len(views)
if total_items and substr: if total_items and substr_parsed:
max_tables = max_items * len(tables) // total_items max_tables = max_items * len(tables) // total_items
max_views = max_items * len(views) // total_items max_views = max_items * len(views) // total_items

View File

@ -143,6 +143,8 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
max_page_size = -1 max_page_size = -1
validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator} validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator}
openapi_spec_tag = "Database"
@expose( @expose(
"/<int:pk>/table/<string:table_name>/<string:schema_name>/", methods=["GET"] "/<int:pk>/table/<string:table_name>/<string:schema_name>/", methods=["GET"]
) )

View File

@ -40,6 +40,13 @@ FAKE_DB_NAME = "fake_db_100"
class SupersetTestCase(TestCase): class SupersetTestCase(TestCase):
# Default schema to use per database backend when a test needs a concrete,
# always-present schema name (e.g. sqlite exposes "main"); keyed by the
# value of Database.backend.
default_schema_backend_map = {
    "sqlite": "main",
    "mysql": "superset",
    "postgresql": "public",
}
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SupersetTestCase, self).__init__(*args, **kwargs) super(SupersetTestCase, self).__init__(*args, **kwargs)
self.maxDiff = None self.maxDiff = None

View File

@ -165,6 +165,43 @@ class CoreTests(SupersetTestCase):
# the new cache_key should be different due to updated datasource # the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new) self.assertNotEqual(cache_key_original, cache_key_new)
def test_get_superset_tables_not_allowed(self):
    """A gamma user without access to the example database gets a 404,
    because the endpoint's DatabaseFilter hides databases the user
    cannot see rather than leaking their existence."""
    database = utils.get_example_database()
    schema = self.default_schema_backend_map[database.backend]
    self.login(username="gamma")
    response = self.client.get(
        f"superset/tables/{database.id}/{schema}/undefined/"
    )
    self.assertEqual(response.status_code, 404)
def test_get_superset_tables_substr(self):
    """Filtering the tables endpoint by substring returns exactly the
    matching table(s) accessible to an admin user.

    Fixes the misspelled local variable ``expeted_response`` ->
    ``expected_response`` (name-only change; assertion behavior is
    identical).
    """
    example_db = utils.get_example_database()
    self.login(username="admin")
    schema_name = self.default_schema_backend_map[example_db.backend]
    # "ab_role" is a Flask-AppBuilder core table, so it always exists.
    uri = f"superset/tables/{example_db.id}/{schema_name}/ab_role/"
    rv = self.client.get(uri)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    expected_response = {
        "options": [
            {
                "label": "ab_role",
                "schema": schema_name,
                "title": "ab_role",
                "type": "table",
                "value": "ab_role",
            }
        ],
        "tableLength": 1,
    }
    self.assertEqual(response, expected_response)
def test_get_superset_tables_not_found(self):
    """A non-integer database id must 404: the route is declared with the
    Flask ``<int:db_id>`` converter, so "invalid" never matches the rule.

    Fixes an ``f``-prefixed string literal that contained no placeholders
    (lint F541) — now a plain string with identical runtime value.
    """
    self.login(username="admin")
    uri = "superset/tables/invalid/public/undefined/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)
def test_api_v1_query_endpoint(self): def test_api_v1_query_endpoint(self):
self.login(username="admin") self.login(username="admin")
qc_dict = self._get_query_context_dict() qc_dict = self._get_query_context_dict()