mirror of https://github.com/apache/superset.git
feat: add name, description and non null tables to RLS (#20432)
* feat: add name, description and non null tables to RLS * add validation * add and fix tests * fix sqlite migration * improve default value for name
This commit is contained in:
parent
8b0bee5e8b
commit
60eb1094a4
|
@ -2482,6 +2482,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
|
||||||
|
|
||||||
__tablename__ = "row_level_security_filters"
|
__tablename__ = "row_level_security_filters"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
|
name = Column(String(255), unique=True, nullable=False)
|
||||||
|
description = Column(Text)
|
||||||
filter_type = Column(
|
filter_type = Column(
|
||||||
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
|
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
|
||||||
)
|
)
|
||||||
|
@ -2494,5 +2496,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
|
||||||
tables = relationship(
|
tables = relationship(
|
||||||
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
|
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
|
||||||
)
|
)
|
||||||
|
|
||||||
clause = Column(Text, nullable=False)
|
clause = Column(Text, nullable=False)
|
||||||
|
|
|
@ -26,7 +26,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||||
from flask_appbuilder.security.decorators import has_access
|
from flask_appbuilder.security.decorators import has_access
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
from wtforms.ext.sqlalchemy.fields import QuerySelectField
|
from wtforms.ext.sqlalchemy.fields import QuerySelectField
|
||||||
from wtforms.validators import Regexp
|
from wtforms.validators import DataRequired, Regexp
|
||||||
|
|
||||||
from superset import app, db
|
from superset import app, db
|
||||||
from superset.connectors.base.views import DatasourceModelView
|
from superset.connectors.base.views import DatasourceModelView
|
||||||
|
@ -47,6 +47,19 @@ from superset.views.base import (
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-methods
|
||||||
|
"""
|
||||||
|
Select required flag on the input field will not work well on Chrome
|
||||||
|
Console error:
|
||||||
|
An invalid form control with name='tables' is not focusable.
|
||||||
|
|
||||||
|
This makes a simple override to the DataRequired to be used specifically with
|
||||||
|
select fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
field_flags = ()
|
||||||
|
|
||||||
|
|
||||||
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
|
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
|
||||||
datamodel = SQLAInterface(models.TableColumn)
|
datamodel = SQLAInterface(models.TableColumn)
|
||||||
# TODO TODO, review need for this on related_views
|
# TODO TODO, review need for this on related_views
|
||||||
|
@ -272,21 +285,39 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
|
||||||
edit_title = _("Edit Row level security filter")
|
edit_title = _("Edit Row level security filter")
|
||||||
|
|
||||||
list_columns = [
|
list_columns = [
|
||||||
|
"name",
|
||||||
|
"filter_type",
|
||||||
|
"tables",
|
||||||
|
"roles",
|
||||||
|
"clause",
|
||||||
|
"creator",
|
||||||
|
"modified",
|
||||||
|
]
|
||||||
|
order_columns = ["name", "filter_type", "clause", "modified"]
|
||||||
|
edit_columns = [
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
"filter_type",
|
"filter_type",
|
||||||
"tables",
|
"tables",
|
||||||
"roles",
|
"roles",
|
||||||
"group_key",
|
"group_key",
|
||||||
"clause",
|
"clause",
|
||||||
"creator",
|
|
||||||
"modified",
|
|
||||||
]
|
]
|
||||||
order_columns = ["filter_type", "group_key", "clause", "modified"]
|
|
||||||
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
|
|
||||||
show_columns = edit_columns
|
show_columns = edit_columns
|
||||||
search_columns = ("filter_type", "tables", "roles", "group_key", "clause")
|
search_columns = (
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"filter_type",
|
||||||
|
"tables",
|
||||||
|
"roles",
|
||||||
|
"group_key",
|
||||||
|
"clause",
|
||||||
|
)
|
||||||
add_columns = edit_columns
|
add_columns = edit_columns
|
||||||
base_order = ("changed_on", "desc")
|
base_order = ("changed_on", "desc")
|
||||||
description_columns = {
|
description_columns = {
|
||||||
|
"name": _("Choose a unique name"),
|
||||||
|
"description": _("Optionally add a detailed description"),
|
||||||
"filter_type": _(
|
"filter_type": _(
|
||||||
"Regular filters add where clauses to queries if a user belongs to a "
|
"Regular filters add where clauses to queries if a user belongs to a "
|
||||||
"role referenced in the filter. Base filters apply filters to all queries "
|
"role referenced in the filter. Base filters apply filters to all queries "
|
||||||
|
@ -319,12 +350,16 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
label_columns = {
|
label_columns = {
|
||||||
|
"name": _("Name"),
|
||||||
|
"description": _("Description"),
|
||||||
"tables": _("Tables"),
|
"tables": _("Tables"),
|
||||||
"roles": _("Roles"),
|
"roles": _("Roles"),
|
||||||
"clause": _("Clause"),
|
"clause": _("Clause"),
|
||||||
"creator": _("Creator"),
|
"creator": _("Creator"),
|
||||||
"modified": _("Modified"),
|
"modified": _("Modified"),
|
||||||
}
|
}
|
||||||
|
validators_columns = {"tables": [SelectDataRequired()]}
|
||||||
|
|
||||||
if app.config["RLS_FORM_QUERY_REL_FIELDS"]:
|
if app.config["RLS_FORM_QUERY_REL_FIELDS"]:
|
||||||
add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"]
|
add_form_query_rel_fields = app.config["RLS_FORM_QUERY_REL_FIELDS"]
|
||||||
edit_form_query_rel_fields = add_form_query_rel_fields
|
edit_form_query_rel_fields = add_form_query_rel_fields
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""add_unique_name_desc_rls
|
||||||
|
|
||||||
|
Revision ID: f3afaf1f11f0
|
||||||
|
Revises: e786798587de
|
||||||
|
Create Date: 2022-06-19 16:17:23.318618
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "f3afaf1f11f0"
|
||||||
|
down_revision = "e786798587de"
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class RowLevelSecurityFilter(Base):
|
||||||
|
__tablename__ = "row_level_security_filters"
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String(255), unique=True, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
bind = op.get_bind()
|
||||||
|
session = Session(bind=bind)
|
||||||
|
|
||||||
|
op.add_column(
|
||||||
|
"row_level_security_filters", sa.Column("name", sa.String(length=255))
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"row_level_security_filters", sa.Column("description", sa.Text(), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set initial default names make sure we can have unique non null values
|
||||||
|
all_rls = session.query(RowLevelSecurityFilter).all()
|
||||||
|
for rls in all_rls:
|
||||||
|
rls.name = f"rls-{rls.id}"
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Now it's safe so set non-null and unique
|
||||||
|
# add unique constraint
|
||||||
|
with op.batch_alter_table("row_level_security_filters") as batch_op:
|
||||||
|
# batch mode is required for sqlite
|
||||||
|
batch_op.alter_column(
|
||||||
|
"name",
|
||||||
|
existing_type=sa.String(255),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
batch_op.create_unique_constraint("uq_rls_name", ["name"])
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_constraint("uq_rls_name", "row_level_security_filters", type_="unique")
|
||||||
|
op.drop_column("row_level_security_filters", "description")
|
||||||
|
op.drop_column("row_level_security_filters", "name")
|
||||||
|
# ### end Alembic commands ###
|
|
@ -25,7 +25,6 @@ from flask import g
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
|
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
|
||||||
from superset.security.guest_token import (
|
from superset.security.guest_token import (
|
||||||
GuestTokenRlsRule,
|
|
||||||
GuestTokenResourceType,
|
GuestTokenResourceType,
|
||||||
GuestUser,
|
GuestUser,
|
||||||
)
|
)
|
||||||
|
@ -82,6 +81,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
||||||
|
|
||||||
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
|
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
|
||||||
self.rls_entry1 = RowLevelSecurityFilter()
|
self.rls_entry1 = RowLevelSecurityFilter()
|
||||||
|
self.rls_entry1.name = "rls_entry1"
|
||||||
self.rls_entry1.tables.extend(
|
self.rls_entry1.tables.extend(
|
||||||
session.query(SqlaTable)
|
session.query(SqlaTable)
|
||||||
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
|
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
|
||||||
|
@ -96,6 +96,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
||||||
|
|
||||||
# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
|
# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
|
||||||
self.rls_entry2 = RowLevelSecurityFilter()
|
self.rls_entry2 = RowLevelSecurityFilter()
|
||||||
|
self.rls_entry2.name = "rls_entry2"
|
||||||
self.rls_entry2.tables.extend(
|
self.rls_entry2.tables.extend(
|
||||||
session.query(SqlaTable)
|
session.query(SqlaTable)
|
||||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||||
|
@ -109,6 +110,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
||||||
|
|
||||||
# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
|
# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
|
||||||
self.rls_entry3 = RowLevelSecurityFilter()
|
self.rls_entry3 = RowLevelSecurityFilter()
|
||||||
|
self.rls_entry3.name = "rls_entry3"
|
||||||
self.rls_entry3.tables.extend(
|
self.rls_entry3.tables.extend(
|
||||||
session.query(SqlaTable)
|
session.query(SqlaTable)
|
||||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||||
|
@ -122,6 +124,7 @@ class TestRowLevelSecurity(SupersetTestCase):
|
||||||
|
|
||||||
# Create Base RowLevelSecurityFilter (birth_names boys)
|
# Create Base RowLevelSecurityFilter (birth_names boys)
|
||||||
self.rls_entry4 = RowLevelSecurityFilter()
|
self.rls_entry4 = RowLevelSecurityFilter()
|
||||||
|
self.rls_entry4.name = "rls_entry4"
|
||||||
self.rls_entry4.tables.extend(
|
self.rls_entry4.tables.extend(
|
||||||
session.query(SqlaTable)
|
session.query(SqlaTable)
|
||||||
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
.filter(SqlaTable.table_name.in_(["birth_names"]))
|
||||||
|
@ -146,6 +149,94 @@ class TestRowLevelSecurity(SupersetTestCase):
|
||||||
session.delete(self.get_user("NoRlsRoleUser"))
|
session.delete(self.get_user("NoRlsRoleUser"))
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def create_dataset(self):
|
||||||
|
with self.create_app().app_context():
|
||||||
|
|
||||||
|
dataset = SqlaTable(database_id=1, schema=None, table_name="table1")
|
||||||
|
db.session.add(dataset)
|
||||||
|
db.session.flush()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
yield dataset
|
||||||
|
|
||||||
|
# rollback changes (assuming cascade delete)
|
||||||
|
db.session.delete(dataset)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
def _get_test_dataset(self):
|
||||||
|
return (
|
||||||
|
db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1")
|
||||||
|
).one_or_none()
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("create_dataset")
|
||||||
|
def test_model_view_rls_add_success(self):
|
||||||
|
self.login(username="admin")
|
||||||
|
test_dataset = self._get_test_dataset()
|
||||||
|
rv = self.client.post(
|
||||||
|
"/rowlevelsecurityfiltersmodelview/add",
|
||||||
|
data=dict(
|
||||||
|
name="rls1",
|
||||||
|
description="Some description",
|
||||||
|
filter_type="Regular",
|
||||||
|
tables=[test_dataset.id],
|
||||||
|
roles=[security_manager.find_role("Alpha").id],
|
||||||
|
group_key="group_key_1",
|
||||||
|
clause="client_id=1",
|
||||||
|
),
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
rls1 = (
|
||||||
|
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
|
||||||
|
).one_or_none()
|
||||||
|
assert rls1 is not None
|
||||||
|
|
||||||
|
# Revert data changes
|
||||||
|
db.session.delete(rls1)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("create_dataset")
|
||||||
|
def test_model_view_rls_add_name_unique(self):
|
||||||
|
self.login(username="admin")
|
||||||
|
test_dataset = self._get_test_dataset()
|
||||||
|
rv = self.client.post(
|
||||||
|
"/rowlevelsecurityfiltersmodelview/add",
|
||||||
|
data=dict(
|
||||||
|
name="rls_entry1",
|
||||||
|
description="Some description",
|
||||||
|
filter_type="Regular",
|
||||||
|
tables=[test_dataset.id],
|
||||||
|
roles=[security_manager.find_role("Alpha").id],
|
||||||
|
group_key="group_key_1",
|
||||||
|
clause="client_id=1",
|
||||||
|
),
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
data = rv.data.decode("utf-8")
|
||||||
|
assert "Already exists." in data
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("create_dataset")
|
||||||
|
def test_model_view_rls_add_tables_required(self):
|
||||||
|
self.login(username="admin")
|
||||||
|
rv = self.client.post(
|
||||||
|
"/rowlevelsecurityfiltersmodelview/add",
|
||||||
|
data=dict(
|
||||||
|
name="rls1",
|
||||||
|
description="Some description",
|
||||||
|
filter_type="Regular",
|
||||||
|
tables=[],
|
||||||
|
roles=[security_manager.find_role("Alpha").id],
|
||||||
|
group_key="group_key_1",
|
||||||
|
clause="client_id=1",
|
||||||
|
),
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
data = rv.data.decode("utf-8")
|
||||||
|
assert "This field is required." in data
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||||
def test_rls_filter_alters_energy_query(self):
|
def test_rls_filter_alters_energy_query(self):
|
||||||
g.user = self.get_user(username="alpha")
|
g.user = self.get_user(username="alpha")
|
||||||
|
|
|
@ -186,6 +186,7 @@ def test_sql_lab_insert_rls(
|
||||||
|
|
||||||
# now with RLS
|
# now with RLS
|
||||||
rls = RowLevelSecurityFilter(
|
rls = RowLevelSecurityFilter(
|
||||||
|
name="sqllab_rls1",
|
||||||
filter_type=RowLevelSecurityFilterType.REGULAR,
|
filter_type=RowLevelSecurityFilterType.REGULAR,
|
||||||
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
|
tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
|
||||||
roles=[admin.roles[0]],
|
roles=[admin.roles[0]],
|
||||||
|
|
Loading…
Reference in New Issue