mirror of https://github.com/apache/superset.git
feat: add name, description and non null tables to RLS (#20432)
* feat: add name, description and non null tables to RLS * add validation * add and fix tests * fix sqlite migration * improve default value for name
This commit is contained in:
parent
8b0bee5e8b
commit
60eb1094a4
|
@ -2482,6 +2482,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
|
|||
|
||||
__tablename__ = "row_level_security_filters"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(255), unique=True, nullable=False)
|
||||
description = Column(Text)
|
||||
filter_type = Column(
|
||||
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
|
||||
)
|
||||
|
@ -2494,5 +2496,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
|
|||
tables = relationship(
|
||||
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
|
||||
)
|
||||
|
||||
clause = Column(Text, nullable=False)
|
||||
|
|
|
@ -26,7 +26,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
|
|||
from flask_appbuilder.security.decorators import has_access
|
||||
from flask_babel import lazy_gettext as _
|
||||
from wtforms.ext.sqlalchemy.fields import QuerySelectField
|
||||
from wtforms.validators import Regexp
|
||||
from wtforms.validators import DataRequired, Regexp
|
||||
|
||||
from superset import app, db
|
||||
from superset.connectors.base.views import DatasourceModelView
|
||||
|
@ -47,6 +47,19 @@ from superset.views.base import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Select required flag on the input field will not work well on Chrome
|
||||
Console error:
|
||||
An invalid form control with name='tables' is not focusable.
|
||||
|
||||
This makes a simple override to the DataRequired to be used specifically with
|
||||
select fields
|
||||
"""
|
||||
|
||||
field_flags = ()
|
||||
|
||||
|
||||
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
|
||||
datamodel = SQLAInterface(models.TableColumn)
|
||||
# TODO TODO, review need for this on related_views
|
||||
|
@ -272,21 +285,39 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
|
|||
edit_title = _("Edit Row level security filter")
|
||||
|
||||
list_columns = [
|
||||
"name",
|
||||
"filter_type",
|
||||
"tables",
|
||||
"roles",
|
||||
"clause",
|
||||
"creator",
|
||||
"modified",
|
||||
]
|
||||
order_columns = ["name", "filter_type", "clause", "modified"]
|
||||
edit_columns = [
|
||||
"name",
|
||||
"description",
|
||||
"filter_type",
|
||||
"tables",
|
||||
"roles",
|
||||
"group_key",
|
||||
"clause",
|
||||
"creator",
|
||||
"modified",
|
||||
]
|
||||
order_columns = ["filter_type", "group_key", "clause", "modified"]
|
||||
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
|
||||
show_columns = edit_columns
|
||||
search_columns = ("filter_type", "tables", "roles", "group_key", "clause")
|
||||
search_columns = (
|
||||
"name",
|
||||
"description",
|
||||
"filter_type",
|
||||
"tables",
|
||||
"roles",
|
||||
"group_key",
|
||||
"clause",
|
||||
)
|
||||
add_columns = edit_columns
|
||||
base_order = ("changed_on", "desc")
|
||||
description_columns = {
|
||||
"name": _("Choose a unique name"),
|
||||
"description": _("Optionally add a detailed description"),
|
||||
"filter_type": _(
|
||||
"Regular filters add where clauses to queries if a user belongs to a "
|
||||
"role referenced in the filter. Base filters apply filters to all queries "
|
||||
|
@ -319,12 +350,16 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
|
|||
),
|
||||
}
|
||||
label_columns = {
|
||||
"name": _("Name"),
|
||||
"description": _("Description"),
|
||||
"tables": _("Tables"),
|
||||
"roles": _("Roles"),
|
||||
"clause": _("Clause"),
|
||||
"creator": _("Creator"),
|
||||
"modified": _("Modified"),
|
||||
}
|
||||
validators_columns = {"tables": [SelectDataRequired()]}
|
||||
|
||||
if app.config["RLS_FORM_QUERY_REL_FIELDS"]:
|
||||
add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"]
|
||||
edit_form_query_rel_fields = add_form_query_rel_fields
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""add_unique_name_desc_rls
|
||||
|
||||
Revision ID: f3afaf1f11f0
|
||||
Revises: e786798587de
|
||||
Create Date: 2022-06-19 16:17:23.318618
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f3afaf1f11f0"
|
||||
down_revision = "e786798587de"
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class RowLevelSecurityFilter(Base):
|
||||
__tablename__ = "row_level_security_filters"
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(255), unique=True, nullable=False)
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
op.add_column(
|
||||
"row_level_security_filters", sa.Column("name", sa.String(length=255))
|
||||
)
|
||||
op.add_column(
|
||||
"row_level_security_filters", sa.Column("description", sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
# Set initial default names make sure we can have unique non null values
|
||||
all_rls = session.query(RowLevelSecurityFilter).all()
|
||||
for rls in all_rls:
|
||||
rls.name = f"rls-{rls.id}"
|
||||
session.commit()
|
||||
|
||||
# Now it's safe so set non-null and unique
|
||||
# add unique constraint
|
||||
with op.batch_alter_table("row_level_security_filters") as batch_op:
|
||||
# batch mode is required for sqlite
|
||||
batch_op.alter_column(
|
||||
"name",
|
||||
existing_type=sa.String(255),
|
||||
nullable=False,
|
||||
)
|
||||
batch_op.create_unique_constraint("uq_rls_name", ["name"])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("uq_rls_name", "row_level_security_filters", type_="unique")
|
||||
op.drop_column("row_level_security_filters", "description")
|
||||
op.drop_column("row_level_security_filters", "name")
|
||||
# ### end Alembic commands ###
|
|
@ -25,7 +25,6 @@ from flask import g
|
|||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
|
||||
from superset.security.guest_token import (
|
||||
GuestTokenRlsRule,
|
||||
GuestTokenResourceType,
|
||||
GuestUser,
|
||||
)
|
||||
|
@ -82,6 +81,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
|
||||
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
|
||||
self.rls_entry1 = RowLevelSecurityFilter()
|
||||
self.rls_entry1.name = "rls_entry1"
|
||||
self.rls_entry1.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
|
||||
|
@ -96,6 +96,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
|
||||
# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
|
||||
self.rls_entry2 = RowLevelSecurityFilter()
|
||||
self.rls_entry2.name = "rls_entry2"
|
||||
self.rls_entry2.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||
|
@ -109,6 +110,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
|
||||
# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
|
||||
self.rls_entry3 = RowLevelSecurityFilter()
|
||||
self.rls_entry3.name = "rls_entry3"
|
||||
self.rls_entry3.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||
|
@ -122,6 +124,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
|
||||
# Create Base RowLevelSecurityFilter (birth_names boys)
|
||||
self.rls_entry4 = RowLevelSecurityFilter()
|
||||
self.rls_entry4.name = "rls_entry4"
|
||||
self.rls_entry4.tables.extend(
|
||||
session.query(SqlaTable)
|
||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||
|
@ -146,6 +149,94 @@ class TestRowLevelSecurity(SupersetTestCase):
|
|||
session.delete(self.get_user("NoRlsRoleUser"))
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture()
|
||||
def create_dataset(self):
|
||||
with self.create_app().app_context():
|
||||
|
||||
dataset = SqlaTable(database_id=1, schema=None, table_name="table1")
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
yield dataset
|
||||
|
||||
# rollback changes (assuming cascade delete)
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
def _get_test_dataset(self):
|
||||
return (
|
||||
db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1")
|
||||
).one_or_none()
|
||||
|
||||
@pytest.mark.usefixtures("create_dataset")
|
||||
def test_model_view_rls_add_success(self):
|
||||
self.login(username="admin")
|
||||
test_dataset = self._get_test_dataset()
|
||||
rv = self.client.post(
|
||||
"/rowlevelsecurityfiltersmodelview/add",
|
||||
data=dict(
|
||||
name="rls1",
|
||||
description="Some description",
|
||||
filter_type="Regular",
|
||||
tables=[test_dataset.id],
|
||||
roles=[security_manager.find_role("Alpha").id],
|
||||
group_key="group_key_1",
|
||||
clause="client_id=1",
|
||||
),
|
||||
follow_redirects=True,
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
rls1 = (
|
||||
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
|
||||
).one_or_none()
|
||||
assert rls1 is not None
|
||||
|
||||
# Revert data changes
|
||||
db.session.delete(rls1)
|
||||
db.session.commit()
|
||||
|
||||
@pytest.mark.usefixtures("create_dataset")
|
||||
def test_model_view_rls_add_name_unique(self):
|
||||
self.login(username="admin")
|
||||
test_dataset = self._get_test_dataset()
|
||||
rv = self.client.post(
|
||||
"/rowlevelsecurityfiltersmodelview/add",
|
||||
data=dict(
|
||||
name="rls_entry1",
|
||||
description="Some description",
|
||||
filter_type="Regular",
|
||||
tables=[test_dataset.id],
|
||||
roles=[security_manager.find_role("Alpha").id],
|
||||
group_key="group_key_1",
|
||||
clause="client_id=1",
|
||||
),
|
||||
follow_redirects=True,
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = rv.data.decode("utf-8")
|
||||
assert "Already exists." in data
|
||||
|
||||
@pytest.mark.usefixtures("create_dataset")
|
||||
def test_model_view_rls_add_tables_required(self):
|
||||
self.login(username="admin")
|
||||
rv = self.client.post(
|
||||
"/rowlevelsecurityfiltersmodelview/add",
|
||||
data=dict(
|
||||
name="rls1",
|
||||
description="Some description",
|
||||
filter_type="Regular",
|
||||
tables=[],
|
||||
roles=[security_manager.find_role("Alpha").id],
|
||||
group_key="group_key_1",
|
||||
clause="client_id=1",
|
||||
),
|
||||
follow_redirects=True,
|
||||
)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
data = rv.data.decode("utf-8")
|
||||
assert "This field is required." in data
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_rls_filter_alters_energy_query(self):
|
||||
g.user = self.get_user(username="alpha")
|
||||
|
|
|
@ -186,6 +186,7 @@ def test_sql_lab_insert_rls(
|
|||
|
||||
# now with RLS
|
||||
rls = RowLevelSecurityFilter(
|
||||
name="sqllab_rls1",
|
||||
filter_type=RowLevelSecurityFilterType.REGULAR,
|
||||
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
|
||||
roles=[admin.roles[0]],
|
||||
|
|
Loading…
Reference in New Issue