feat: add name, description and non null tables to RLS (#20432)

* feat: add name, description and non null tables to RLS

* add validation

* add and fix tests

* fix sqlite migration

* improve default value for name
This commit is contained in:
Daniel Vaz Gaspar 2022-06-20 13:52:05 +01:00 committed by GitHub
parent 8b0bee5e8b
commit 60eb1094a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 215 additions and 8 deletions

View File

@ -2482,6 +2482,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
__tablename__ = "row_level_security_filters" __tablename__ = "row_level_security_filters"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String(255), unique=True, nullable=False)
description = Column(Text)
filter_type = Column( filter_type = Column(
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType]) Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
) )
@ -2494,5 +2496,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
tables = relationship( tables = relationship(
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters" SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
) )
clause = Column(Text, nullable=False) clause = Column(Text, nullable=False)

View File

@ -26,7 +26,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access from flask_appbuilder.security.decorators import has_access
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from wtforms.ext.sqlalchemy.fields import QuerySelectField from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import Regexp from wtforms.validators import DataRequired, Regexp
from superset import app, db from superset import app, db
from superset.connectors.base.views import DatasourceModelView from superset.connectors.base.views import DatasourceModelView
@ -47,6 +47,19 @@ from superset.views.base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-methods
"""
Select required flag on the input field will not work well on Chrome
Console error:
An invalid form control with name='tables' is not focusable.
This makes a simple override to the DataRequired to be used specifically with
select fields
"""
field_flags = ()
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
datamodel = SQLAInterface(models.TableColumn) datamodel = SQLAInterface(models.TableColumn)
# TODO TODO, review need for this on related_views # TODO TODO, review need for this on related_views
@ -272,21 +285,39 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
edit_title = _("Edit Row level security filter") edit_title = _("Edit Row level security filter")
list_columns = [ list_columns = [
"name",
"filter_type",
"tables",
"roles",
"clause",
"creator",
"modified",
]
order_columns = ["name", "filter_type", "clause", "modified"]
edit_columns = [
"name",
"description",
"filter_type", "filter_type",
"tables", "tables",
"roles", "roles",
"group_key", "group_key",
"clause", "clause",
"creator",
"modified",
] ]
order_columns = ["filter_type", "group_key", "clause", "modified"]
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
show_columns = edit_columns show_columns = edit_columns
search_columns = ("filter_type", "tables", "roles", "group_key", "clause") search_columns = (
"name",
"description",
"filter_type",
"tables",
"roles",
"group_key",
"clause",
)
add_columns = edit_columns add_columns = edit_columns
base_order = ("changed_on", "desc") base_order = ("changed_on", "desc")
description_columns = { description_columns = {
"name": _("Choose a unique name"),
"description": _("Optionally add a detailed description"),
"filter_type": _( "filter_type": _(
"Regular filters add where clauses to queries if a user belongs to a " "Regular filters add where clauses to queries if a user belongs to a "
"role referenced in the filter. Base filters apply filters to all queries " "role referenced in the filter. Base filters apply filters to all queries "
@ -319,12 +350,16 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
), ),
} }
label_columns = { label_columns = {
"name": _("Name"),
"description": _("Description"),
"tables": _("Tables"), "tables": _("Tables"),
"roles": _("Roles"), "roles": _("Roles"),
"clause": _("Clause"), "clause": _("Clause"),
"creator": _("Creator"), "creator": _("Creator"),
"modified": _("Modified"), "modified": _("Modified"),
} }
validators_columns = {"tables": [SelectDataRequired()]}
if app.config["RLS_FORM_QUERY_REL_FIELDS"]: if app.config["RLS_FORM_QUERY_REL_FIELDS"]:
add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"] add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"]
edit_form_query_rel_fields = add_form_query_rel_fields edit_form_query_rel_fields = add_form_query_rel_fields

View File

@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_unique_name_desc_rls
Revision ID: f3afaf1f11f0
Revises: e786798587de
Create Date: 2022-06-19 16:17:23.318618
"""
# revision identifiers, used by Alembic.
revision = "f3afaf1f11f0"
down_revision = "e786798587de"
import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
Base = declarative_base()
class RowLevelSecurityFilter(Base):
__tablename__ = "row_level_security_filters"
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(255), unique=True, nullable=False)
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
bind = op.get_bind()
session = Session(bind=bind)
op.add_column(
"row_level_security_filters", sa.Column("name", sa.String(length=255))
)
op.add_column(
"row_level_security_filters", sa.Column("description", sa.Text(), nullable=True)
)
# Set initial default names make sure we can have unique non null values
all_rls = session.query(RowLevelSecurityFilter).all()
for rls in all_rls:
rls.name = f"rls-{rls.id}"
session.commit()
# Now it's safe so set non-null and unique
# add unique constraint
with op.batch_alter_table("row_level_security_filters") as batch_op:
# batch mode is required for sqlite
batch_op.alter_column(
"name",
existing_type=sa.String(255),
nullable=False,
)
batch_op.create_unique_constraint("uq_rls_name", ["name"])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("uq_rls_name", "row_level_security_filters", type_="unique")
op.drop_column("row_level_security_filters", "description")
op.drop_column("row_level_security_filters", "name")
# ### end Alembic commands ###

View File

@ -25,7 +25,6 @@ from flask import g
from superset import db, security_manager from superset import db, security_manager
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import ( from superset.security.guest_token import (
GuestTokenRlsRule,
GuestTokenResourceType, GuestTokenResourceType,
GuestUser, GuestUser,
) )
@ -82,6 +81,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test) # Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
self.rls_entry1 = RowLevelSecurityFilter() self.rls_entry1 = RowLevelSecurityFilter()
self.rls_entry1.name = "rls_entry1"
self.rls_entry1.tables.extend( self.rls_entry1.tables.extend(
session.query(SqlaTable) session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
@ -96,6 +96,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create regular RowLevelSecurityFilter (birth_names name starts with A or B) # Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
self.rls_entry2 = RowLevelSecurityFilter() self.rls_entry2 = RowLevelSecurityFilter()
self.rls_entry2.name = "rls_entry2"
self.rls_entry2.tables.extend( self.rls_entry2.tables.extend(
session.query(SqlaTable) session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"])) .filter(SqlaTable.table_name.in_(["birth_names"]))
@ -109,6 +110,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create Regular RowLevelSecurityFilter (birth_names name starts with Q) # Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
self.rls_entry3 = RowLevelSecurityFilter() self.rls_entry3 = RowLevelSecurityFilter()
self.rls_entry3.name = "rls_entry3"
self.rls_entry3.tables.extend( self.rls_entry3.tables.extend(
session.query(SqlaTable) session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"])) .filter(SqlaTable.table_name.in_(["birth_names"]))
@ -122,6 +124,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create Base RowLevelSecurityFilter (birth_names boys) # Create Base RowLevelSecurityFilter (birth_names boys)
self.rls_entry4 = RowLevelSecurityFilter() self.rls_entry4 = RowLevelSecurityFilter()
self.rls_entry4.name = "rls_entry4"
self.rls_entry4.tables.extend( self.rls_entry4.tables.extend(
session.query(SqlaTable) session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"])) .filter(SqlaTable.table_name.in_(["birth_names"]))
@ -146,6 +149,94 @@ class TestRowLevelSecurity(SupersetTestCase):
session.delete(self.get_user("NoRlsRoleUser")) session.delete(self.get_user("NoRlsRoleUser"))
session.commit() session.commit()
@pytest.fixture()
def create_dataset(self):
with self.create_app().app_context():
dataset = SqlaTable(database_id=1, schema=None, table_name="table1")
db.session.add(dataset)
db.session.flush()
db.session.commit()
yield dataset
# rollback changes (assuming cascade delete)
db.session.delete(dataset)
db.session.commit()
def _get_test_dataset(self):
return (
db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1")
).one_or_none()
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_success(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
rls1 = (
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
).one_or_none()
assert rls1 is not None
# Revert data changes
db.session.delete(rls1)
db.session.commit()
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_name_unique(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls_entry1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "Already exists." in data
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
self.login(username="admin")
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "This field is required." in data
@pytest.mark.usefixtures("load_energy_table_with_slice") @pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self): def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha") g.user = self.get_user(username="alpha")

View File

@ -186,6 +186,7 @@ def test_sql_lab_insert_rls(
# now with RLS # now with RLS
rls = RowLevelSecurityFilter( rls = RowLevelSecurityFilter(
name="sqllab_rls1",
filter_type=RowLevelSecurityFilterType.REGULAR, filter_type=RowLevelSecurityFilterType.REGULAR,
tables=[SqlaTable(database_id=1, schema=None, table_name="t")], tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
roles=[admin.roles[0]], roles=[admin.roles[0]],