# Mirror of https://github.com/apache/superset.git (synced 2024-09-17 11:09:47 -04:00).
# Original file: 576 lines, 20 KiB, Python.
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
# pylint: disable=import-outside-toplevel, unused-argument, unused-import, invalid-name
|
|
|
|
import copy
|
|
import json
|
|
import re
|
|
import uuid
|
|
from typing import Any
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
from flask import current_app
|
|
from pytest_mock import MockFixture
|
|
from sqlalchemy.orm.session import Session
|
|
|
|
from superset.datasets.commands.exceptions import (
|
|
DatasetForbiddenDataURI,
|
|
ImportFailedError,
|
|
)
|
|
from superset.datasets.commands.importers.v1.utils import validate_data_uri
|
|
|
|
|
|
def test_import_dataset(mocker: MockFixture, session: Session) -> None:
    """
    Test importing a dataset.

    A complete v1 dataset config (one metric, one column) is passed to
    ``import_dataset`` and the attributes of the returned ``SqlaTable``, its
    metric and its column are checked. The nested dicts (``params``,
    ``template_params``, ``extra``) are expected back as JSON strings, and the
    column flags given as ``None`` are expected back with the values asserted
    below.
    """
    from superset import security_manager
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database

    # Every ``security_manager.can_access`` check answers True during the test.
    mocker.patch.object(security_manager, "can_access", return_value=True)

    # Create the model tables on the engine bound to this session.
    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    # Flush so ``database.id`` is populated before the config references it.
    session.flush()

    dataset_uuid = uuid.uuid4()

    metric_config = {
        "metric_name": "cnt",
        "verbose_name": None,
        "metric_type": None,
        "expression": "COUNT(*)",
        "description": None,
        "d3format": None,
        "extra": {"warning_markdown": None},
        "warning_text": None,
    }
    column_config = {
        "column_name": "profit",
        "verbose_name": None,
        "is_dttm": None,
        "is_active": None,
        "type": "INTEGER",
        "groupby": None,
        "filterable": None,
        "expression": "revenue-expenses",
        "description": None,
        "python_date_format": None,
        "extra": {"certified_by": "User"},
    }
    config = {
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {"answer": "42"},
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [metric_config],
        "columns": [column_config],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config)

    # Table-level attributes; nested dicts come back serialized as JSON.
    assert sqla_table.table_name == "my_table"
    assert sqla_table.main_dttm_col == "ds"
    assert sqla_table.description == "This is the description"
    assert sqla_table.default_endpoint is None
    assert sqla_table.offset == -8
    assert sqla_table.cache_timeout == 3600
    assert sqla_table.schema == "my_schema"
    assert sqla_table.sql is None
    assert sqla_table.params == json.dumps(
        {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
    )
    assert sqla_table.template_params == json.dumps({"answer": "42"})
    assert sqla_table.filter_select_enabled is True
    assert sqla_table.fetch_values_predicate == "foo IN (1, 2)"
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
    assert sqla_table.uuid == dataset_uuid

    # The single metric round-trips unchanged, apart from ``extra`` as JSON.
    assert len(sqla_table.metrics) == 1
    metric = sqla_table.metrics[0]
    assert metric.metric_name == "cnt"
    assert metric.verbose_name is None
    assert metric.metric_type is None
    assert metric.expression == "COUNT(*)"
    assert metric.description is None
    assert metric.d3format is None
    assert metric.extra == '{"warning_markdown": null}'
    assert metric.warning_text is None

    # The ``None`` flags from the config are expected to come back as these values.
    assert len(sqla_table.columns) == 1
    column = sqla_table.columns[0]
    assert column.column_name == "profit"
    assert column.verbose_name is None
    assert column.is_dttm is False
    assert column.is_active is True
    assert column.type == "INTEGER"
    assert column.groupby is True
    assert column.filterable is True
    assert column.expression == "revenue-expenses"
    assert column.description is None
    assert column.python_date_format is None
    assert column.extra == '{"certified_by": "User"}'

    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
|
|
|
|
|
|
def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session) -> None:
    """
    Test importing a dataset with a column that already exists.

    A ``SqlaTable`` with the same UUID and a ``TableColumn`` named
    ``existing_column`` are persisted first. The config then reuses that
    table name and column name and is imported with ``overwrite=True``. The
    import must succeed and leave the dataset with exactly one column.
    """
    from superset import security_manager

    # Not referenced by name below. NOTE(review): the import may matter for
    # registering its model on the shared metadata before ``create_all`` --
    # unverified, so it is kept.
    from superset.columns.models import Column as NewColumn
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database

    # Every ``security_manager.can_access`` check answers True during the test.
    mocker.patch.object(security_manager, "can_access", return_value=True)

    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    dataset_uuid = uuid.uuid4()

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    # Rows that already exist before the import runs.
    existing_dataset = SqlaTable(
        uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
    )
    # NOTE(review): this column is not linked to ``existing_dataset`` (no
    # table is assigned); confirm it still exercises the duplicate-column path.
    existing_column = TableColumn(column_name="existing_column")
    session.add(existing_dataset)
    session.add(existing_column)
    session.flush()

    metric_config = {
        "metric_name": "cnt",
        "verbose_name": None,
        "metric_type": None,
        "expression": "COUNT(*)",
        "description": None,
        "d3format": None,
        "extra": {"warning_markdown": None},
        "warning_text": None,
    }
    column_config = {
        "column_name": existing_column.column_name,
        "verbose_name": None,
        "is_dttm": None,
        "is_active": None,
        "type": "INTEGER",
        "groupby": None,
        "filterable": None,
        "expression": "revenue-expenses",
        "description": None,
        "python_date_format": None,
        "extra": {"certified_by": "User"},
    }
    config = {
        "table_name": existing_dataset.table_name,
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {"answer": "42"},
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [metric_config],
        "columns": [column_config],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config, overwrite=True)

    assert sqla_table.table_name == existing_dataset.table_name
    assert sqla_table.main_dttm_col == "ds"
    assert sqla_table.description == "This is the description"
    assert sqla_table.default_endpoint is None
    assert sqla_table.offset == -8
    assert sqla_table.cache_timeout == 3600
    assert sqla_table.schema == "my_schema"
    assert sqla_table.sql is None
    assert sqla_table.params == json.dumps(
        {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
    )
    assert sqla_table.template_params == json.dumps({"answer": "42"})
    assert sqla_table.filter_select_enabled is True
    assert sqla_table.fetch_values_predicate == "foo IN (1, 2)"
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
    assert sqla_table.uuid == dataset_uuid

    assert len(sqla_table.metrics) == 1
    metric = sqla_table.metrics[0]
    assert metric.metric_name == "cnt"
    assert metric.verbose_name is None
    assert metric.metric_type is None
    assert metric.expression == "COUNT(*)"
    assert metric.description is None
    assert metric.d3format is None
    assert metric.extra == '{"warning_markdown": null}'
    assert metric.warning_text is None

    # Only one column may remain after overwriting with the same name.
    assert len(sqla_table.columns) == 1
    imported_column = sqla_table.columns[0]
    assert imported_column.column_name == existing_column.column_name
    assert imported_column.verbose_name is None
    assert imported_column.is_dttm is False
    assert imported_column.is_active is True
    assert imported_column.type == "INTEGER"
    assert imported_column.groupby is True
    assert imported_column.filterable is True
    assert imported_column.expression == "revenue-expenses"
    assert imported_column.description is None
    assert imported_column.python_date_format is None
    assert imported_column.extra == '{"certified_by": "User"}'

    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
|
|
|
|
|
|
def test_import_column_extra_is_string(mocker: MockFixture, session: Session) -> None:
    """
    Test importing a dataset when the column extra is a string.

    The dataset, metric and column ``extra`` fields are given as JSON
    *strings* rather than dicts. The config goes through
    ``ImportV1DatasetSchema`` first, as a YAML import would, and the imported
    objects must end up with the same serialized ``extra`` values.
    """
    from superset import security_manager
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    # Every ``security_manager.can_access`` check answers True during the test.
    mocker.patch.object(security_manager, "can_access", return_value=True)

    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()

    metric_entry = {
        "metric_name": "cnt",
        "verbose_name": None,
        "metric_type": None,
        "expression": "COUNT(*)",
        "description": None,
        "d3format": None,
        "extra": '{"warning_markdown": null}',
        "warning_text": None,
    }
    column_entry = {
        "column_name": "profit",
        "verbose_name": None,
        "is_dttm": False,
        "is_active": True,
        "type": "INTEGER",
        "groupby": False,
        "filterable": False,
        "expression": "revenue-expenses",
        "description": None,
        "python_date_format": None,
        "extra": '{"certified_by": "User"}',
    }
    yaml_config: dict[str, Any] = {
        "version": "1.0.0",
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {"answer": "42"},
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": '{"warning_markdown": "*WARNING*"}',
        "uuid": dataset_uuid,
        "metrics": [metric_entry],
        "columns": [column_entry],
        "database_uuid": database.uuid,
    }

    # The Marshmallow schema is expected to turn the JSON strings into objects.
    dataset_config = ImportV1DatasetSchema().load(yaml_config)
    dataset_config["database_id"] = database.id
    sqla_table = import_dataset(session, dataset_config)

    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
|
|
|
|
|
|
def test_import_dataset_extra_empty_string(
    mocker: MockFixture, session: Session
) -> None:
    """
    Test importing a dataset when the extra field is a blank string.

    The dataset ``extra`` is whitespace-only (``" "``). After the config is
    loaded through ``ImportV1DatasetSchema`` and imported, the table's
    ``extra`` must be ``None`` rather than the blank string.
    """
    from superset import security_manager
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    # Every ``security_manager.can_access`` check answers True during the test.
    mocker.patch.object(security_manager, "can_access", return_value=True)

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    yaml_config: dict[str, Any] = {
        "version": "1.0.0",
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        # Whitespace-only, not strictly empty: exercises the blank-string path.
        "extra": " ",
        "uuid": dataset_uuid,
        "metrics": [
            {
                "metric_name": "cnt",
                "expression": "COUNT(*)",
            }
        ],
        "columns": [
            {
                "column_name": "profit",
                "is_dttm": False,
                "is_active": True,
                "type": "INTEGER",
                "groupby": False,
                "filterable": False,
                "expression": "revenue-expenses",
            }
        ],
        "database_uuid": database.uuid,
    }

    schema = ImportV1DatasetSchema()
    dataset_config = schema.load(yaml_config)
    dataset_config["database_id"] = database.id
    sqla_table = import_dataset(session, dataset_config)

    # Compare to the ``None`` singleton by identity (PEP 8 / E711).
    assert sqla_table.extra is None
|
|
|
|
|
|
@patch("superset.datasets.commands.importers.v1.utils.request")
def test_import_column_allowed_data_url(
    request: Mock,
    mocker: MockFixture,
    session: Session,
) -> None:
    """
    Test importing a dataset when using data key to fetch data from a URL.

    ``request`` replaces the ``request`` name inside the importer module, so
    ``request.urlopen`` returns an in-memory CSV instead of reaching the
    network. With ``force_data=True`` the CSV rows are expected to be loaded
    into ``my_table`` on the session's database.
    """
    import io

    from sqlalchemy import text

    from superset import security_manager
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    # Serve a one-column CSV in place of the remote file.
    request.urlopen.return_value = io.StringIO("col1\nvalue1\nvalue2\n")

    # Every ``security_manager.can_access`` check answers True during the test.
    mocker.patch.object(security_manager, "can_access", return_value=True)

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    yaml_config: dict[str, Any] = {
        "version": "1.0.0",
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": None,
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": None,
        "filter_select_enabled": True,
        "fetch_values_predicate": None,
        "extra": None,
        "uuid": dataset_uuid,
        "metrics": [],
        "columns": [
            {
                "column_name": "col1",
                "verbose_name": None,
                "is_dttm": False,
                "is_active": True,
                "type": "TEXT",
                "groupby": False,
                "filterable": False,
                "expression": None,
                "description": None,
                "python_date_format": None,
                "extra": None,
            }
        ],
        "database_uuid": database.uuid,
        "data": "https://some-external-url.com/data.csv",
    }

    # the Marshmallow schema should convert strings to objects
    schema = ImportV1DatasetSchema()
    dataset_config = schema.load(yaml_config)
    dataset_config["database_id"] = database.id
    import_dataset(session, dataset_config, force_data=True)

    # NOTE(review): kept from the original; its purpose (presumably making
    # sure the session holds a connection before querying) is unverified.
    session.connection()
    # Wrap raw SQL in ``text()``: passing a plain string to
    # ``Session.execute`` is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
    rows = session.execute(text("SELECT * FROM my_table")).fetchall()
    assert [("value1",), ("value2",)] == rows
|
|
|
|
|
|
def test_import_dataset_managed_externally(
    mocker: MockFixture,
    session: Session,
) -> None:
    """
    Test importing a dataset that is managed externally.

    Starts from the shared ``dataset_config`` fixture, marks it as externally
    managed with an external URL, and checks that both fields are carried
    onto the imported ``SqlaTable``.
    """
    from superset import security_manager
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import dataset_config

    # Every ``security_manager.can_access`` check answers True during the test.
    mocker.patch.object(security_manager, "can_access", return_value=True)

    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    # Work on a deep copy so the shared fixture is never mutated.
    external_url = "https://example.org/my_table"
    config = {
        **copy.deepcopy(dataset_config),
        "is_managed_externally": True,
        "external_url": external_url,
        "database_id": database.id,
    }

    imported = import_dataset(session, config)
    assert imported.is_managed_externally is True
    assert imported.external_url == external_url
|
|
|
|
|
|
@pytest.mark.parametrize(
    "allowed_urls, data_uri, expected, exception_class",
    [
        ([r".*"], "https://some-url/data.csv", True, None),
        (
            [r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
            "https://host1.domain1.com/data.csv",
            True,
            None,
        ),
        (
            [r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
            "https://host2.domain1.com/data.csv",
            True,
            None,
        ),
        (
            [r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
            "https://host1.domain2.com/data.csv",
            True,
            None,
        ),
        (
            [r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
            "https://host1.domain3.com/data.csv",
            False,
            DatasetForbiddenDataURI,
        ),
        ([], "https://host1.domain3.com/data.csv", False, DatasetForbiddenDataURI),
        (["*"], "https://host1.domain3.com/data.csv", False, re.error),
    ],
)
def test_validate_data_uri(allowed_urls, data_uri, expected, exception_class) -> None:
    """
    Check ``validate_data_uri`` against ``DATASET_IMPORT_ALLOWED_DATA_URLS``.

    Each case installs ``allowed_urls`` as the allow-list of regex patterns.
    ``data_uri`` must then either validate without error (``expected`` is
    True) or raise ``exception_class``. In the cases above, the empty
    allow-list is expected to reject the URI, and the invalid pattern
    ``"*"`` is expected to surface as ``re.error``.

    The override is applied with ``patch.dict`` so the original config value
    is restored after each case instead of leaking into later tests.
    """
    with patch.dict(
        current_app.config, {"DATASET_IMPORT_ALLOWED_DATA_URLS": allowed_urls}
    ):
        if expected:
            validate_data_uri(data_uri)
        else:
            with pytest.raises(exception_class):
                validate_data_uri(data_uri)
|