chore: consolidate datasource import logic (#11533)

* Consolidate dash import logic

* WIP

* Add license

* Fix lint

* Retrigger tests

* Fix lint
Beto Dealmeida 2020-11-11 22:04:16 -08:00 committed by GitHub
parent 6ef4d2a991
commit 45738ffc1d
10 changed files with 365 additions and 297 deletions

View File

@@ -301,11 +301,11 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
 )
 def import_datasources(path: str, sync: str, recursive: bool) -> None:
     """Import datasources from YAML"""
-    from superset.utils import dict_import_export
+    from superset.datasets.commands.importers.v0 import ImportDatasetsCommand

     sync_array = sync.split(",")
     path_object = Path(path)
-    files = []
+    files: List[Path] = []
     if path_object.is_file():
         files.append(path_object)
     elif path_object.exists() and not recursive:
@@ -314,16 +314,11 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
     elif path_object.exists() and recursive:
         files.extend(path_object.rglob("*.yaml"))
         files.extend(path_object.rglob("*.yml"))
-    for file_ in files:
-        logger.info("Importing datasources from file %s", file_)
-        try:
-            with file_.open() as data_stream:
-                dict_import_export.import_from_dict(
-                    db.session, yaml.safe_load(data_stream), sync=sync_array
-                )
-        except Exception as ex:  # pylint: disable=broad-except
-            logger.error("Error when importing datasources from file %s", file_)
-            logger.error(ex)
+    contents = {path.name: open(path).read() for path in files}
+    try:
+        ImportDatasetsCommand(contents, sync_array).run()
+    except Exception:  # pylint: disable=broad-except
+        logger.exception("Error when importing dataset")


 @superset.command()
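
For reference, a minimal sketch of driving the new command the way the refactored CLI now does, assuming an initialized Superset app context; the file name and sync values are illustrative:

from pathlib import Path

from superset.datasets.commands.importers.v0 import ImportDatasetsCommand

# Collect YAML exports into a {file name: file contents} mapping, as the CLI does.
files = [Path("exports/my_datasets.yaml")]  # hypothetical export file
contents = {path.name: path.read_text() for path in files}

# sync mirrors the CLI's --sync flag, e.g. "metrics,columns" split into a list.
ImportDatasetsCommand(contents, sync=["metrics", "columns"]).run()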

View File

@@ -56,7 +56,7 @@ from superset.exceptions import SupersetException
 from superset.models.core import Database
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult
 from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict
-from superset.utils import core as utils, import_datasource
+from superset.utils import core as utils

 try:
     import requests
@@ -378,20 +378,6 @@ class DruidColumn(Model, BaseColumn):
                     metric.datasource_id = self.datasource_id
                     db.session.add(metric)

-    @classmethod
-    def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
-        def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]:
-            return (
-                db.session.query(DruidColumn)
-                .filter(
-                    DruidColumn.datasource_id == lookup_column.datasource_id,
-                    DruidColumn.column_name == lookup_column.column_name,
-                )
-                .first()
-            )
-
-        return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
-

 class DruidMetric(Model, BaseMetric):
@@ -447,20 +433,6 @@ class DruidMetric(Model, BaseMetric):
     def get_perm(self) -> Optional[str]:
         return self.perm

-    @classmethod
-    def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric":
-        def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]:
-            return (
-                db.session.query(DruidMetric)
-                .filter(
-                    DruidMetric.datasource_id == lookup_metric.datasource_id,
-                    DruidMetric.metric_name == lookup_metric.metric_name,
-                )
-                .first()
-            )
-
-        return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
-

 druiddatasource_user = Table(
     "druiddatasource_user",
@@ -610,34 +582,6 @@ class DruidDatasource(Model, BaseDatasource):
     def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
         return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]

-    @classmethod
-    def import_obj(
-        cls, i_datasource: "DruidDatasource", import_time: Optional[int] = None
-    ) -> int:
-        """Imports the datasource from the object to the database.
-
-        Metrics and columns and datasource will be overridden if exists.
-        This function can be used to import/export dashboards between multiple
-        superset instances. Audit metadata isn't copies over.
-        """
-
-        def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]:
-            return (
-                db.session.query(DruidDatasource)
-                .filter(
-                    DruidDatasource.datasource_name == d.datasource_name,
-                    DruidDatasource.cluster_id == d.cluster_id,
-                )
-                .first()
-            )
-
-        def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
-            return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()
-
-        return import_datasource.import_datasource(
-            db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
-        )
-
     def latest_metadata(self) -> Optional[Dict[str, Any]]:
         """Returns segment metadata from the latest segment"""
         logger.info("Syncing datasource [{}]".format(self.datasource_name))

View File

@@ -47,7 +47,6 @@ from sqlalchemy import (
 )
 from sqlalchemy.exc import CompileError
 from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
-from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
 from sqlalchemy.sql.expression import Label, Select, TextAsFrom
@@ -58,11 +57,7 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
 from superset.constants import NULL_STRING
 from superset.db_engine_specs.base import TimestampExpression
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import (
-    DatabaseNotFound,
-    QueryObjectValidationError,
-    SupersetSecurityException,
-)
+from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
 from superset.jinja_context import (
     BaseTemplateProcessor,
     ExtraCache,
@@ -74,7 +69,7 @@ from superset.models.helpers import AuditMixinNullable, QueryResult
 from superset.result_set import SupersetResultSet
 from superset.sql_parse import ParsedQuery
 from superset.typing import Metric, QueryObjectDict
-from superset.utils import core as utils, import_datasource
+from superset.utils import core as utils

 config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
@@ -290,20 +285,6 @@ class TableColumn(Model, BaseColumn):
         )
         return self.table.make_sqla_column_compatible(time_expr, label)

-    @classmethod
-    def import_obj(cls, i_column: "TableColumn") -> "TableColumn":
-        def lookup_obj(lookup_column: TableColumn) -> TableColumn:
-            return (
-                db.session.query(TableColumn)
-                .filter(
-                    TableColumn.table_id == lookup_column.table_id,
-                    TableColumn.column_name == lookup_column.column_name,
-                )
-                .first()
-            )
-
-        return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
-
     def dttm_sql_literal(
         self,
         dttm: DateTime,
@@ -412,20 +393,6 @@ class SqlMetric(Model, BaseMetric):
     def get_perm(self) -> Optional[str]:
         return self.perm

-    @classmethod
-    def import_obj(cls, i_metric: "SqlMetric") -> "SqlMetric":
-        def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric:
-            return (
-                db.session.query(SqlMetric)
-                .filter(
-                    SqlMetric.table_id == lookup_metric.table_id,
-                    SqlMetric.metric_name == lookup_metric.metric_name,
-                )
-                .first()
-            )
-
-        return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
-
     def get_extra_dict(self) -> Dict[str, Any]:
         try:
             return json.loads(self.extra)
@@ -1416,56 +1383,6 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-attributes
         db.session.commit()
         return results

-    @classmethod
-    def import_obj(
-        cls,
-        i_datasource: "SqlaTable",
-        database_id: Optional[int] = None,
-        import_time: Optional[int] = None,
-    ) -> int:
-        """Imports the datasource from the object to the database.
-
-        Metrics and columns and datasource will be overrided if exists.
-        This function can be used to import/export dashboards between multiple
-        superset instances. Audit metadata isn't copies over.
-        """
-
-        def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable":
-            return (
-                db.session.query(SqlaTable)
-                .join(Database)
-                .filter(
-                    SqlaTable.table_name == table_.table_name,
-                    SqlaTable.schema == table_.schema,
-                    Database.id == table_.database_id,
-                )
-                .first()
-            )
-
-        def lookup_database(table_: SqlaTable) -> Database:
-            try:
-                return (
-                    db.session.query(Database)
-                    .filter_by(database_name=table_.params_dict["database_name"])
-                    .one()
-                )
-            except NoResultFound:
-                raise DatabaseNotFound(
-                    _(
-                        "Database '%(name)s' is not found",
-                        name=table_.params_dict["database_name"],
-                    )
-                )
-
-        return import_datasource.import_datasource(
-            db.session,
-            i_datasource,
-            lookup_database,
-            lookup_sqlatable,
-            import_time,
-            database_id,
-        )
-
     @classmethod
     def query_datasources_by_name(
         cls,
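
On the SQLA side, the database lookup moved into the importer as well: a missing database now surfaces as DatabaseNotFoundError from superset.databases.commands.exceptions rather than the DatabaseNotFound exception previously imported here. A sketch of handling it (the sqla_table variable is illustrative):

from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.datasets.commands.importers.v0 import import_dataset

try:
    new_id = import_dataset(sqla_table)  # no database_id given, so the importer looks it up by name
except DatabaseNotFoundError:
    # raised when params_dict["database_name"] matches no Database row
    ...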

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -19,7 +19,7 @@ import logging
 import time
 from copy import copy
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from flask_babel import lazy_gettext as _
 from sqlalchemy.orm import make_transient, Session
@@ -27,6 +27,7 @@ from sqlalchemy.orm import make_transient, Session
 from superset import ConnectorRegistry, db
 from superset.commands.base import BaseCommand
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from superset.datasets.commands.importers.v0 import import_dataset
 from superset.exceptions import DashboardImportException
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
@@ -301,7 +302,7 @@ def import_dashboards(
     if not data:
         raise DashboardImportException(_("No data in file"))
     for table in data["datasources"]:
-        type(table).import_obj(table, database_id, import_time=import_time)
+        import_dataset(table, database_id, import_time=import_time)
     session.commit()
     for dashboard in data["dashboards"]:
         import_dashboard(dashboard, import_time=import_time)
@@ -333,4 +334,5 @@ class ImportDashboardsCommand(BaseCommand):
             try:
                 json.loads(content)
             except ValueError:
+                logger.exception("Invalid JSON file")
                 raise
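
Dashboard imports now funnel every bundled datasource through the same import_dataset helper. A rough usage sketch, assuming the v0 dashboard command takes the same {file name: contents} mapping as the new ImportDatasetsCommand introduced in this change (the export file name is illustrative):

from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand

contents = {"dashboards.json": open("dashboards.json").read()}  # hypothetical export
ImportDashboardsCommand(contents).run()  # invalid JSON is now logged before the ValueError is re-raised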

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,303 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Callable, Dict, List, Optional
import yaml
from flask_appbuilder import Model
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import make_transient
from superset import db
from superset.commands.base import BaseCommand
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.connectors.druid.models import (
DruidCluster,
DruidColumn,
DruidDatasource,
DruidMetric,
)
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.models.core import Database
from superset.utils.dict_import_export import DATABASES_KEY, DRUID_CLUSTERS_KEY
logger = logging.getLogger(__name__)
def lookup_sqla_table(table: SqlaTable) -> Optional[SqlaTable]:
return (
db.session.query(SqlaTable)
.join(Database)
.filter(
SqlaTable.table_name == table.table_name,
SqlaTable.schema == table.schema,
Database.id == table.database_id,
)
.first()
)
def lookup_sqla_database(table: SqlaTable) -> Optional[Database]:
try:
return (
db.session.query(Database)
.filter_by(database_name=table.params_dict["database_name"])
.one()
)
except NoResultFound:
raise DatabaseNotFoundError
def lookup_druid_cluster(datasource: DruidDatasource) -> Optional[DruidCluster]:
return db.session.query(DruidCluster).filter_by(id=datasource.cluster_id).first()
def lookup_druid_datasource(datasource: DruidDatasource) -> Optional[DruidDatasource]:
return (
db.session.query(DruidDatasource)
.filter(
DruidDatasource.datasource_name == datasource.datasource_name,
DruidDatasource.cluster_id == datasource.cluster_id,
)
.first()
)
def import_dataset(
i_datasource: BaseDatasource,
database_id: Optional[int] = None,
import_time: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overridden if exists.
This function can be used to import/export dashboards between multiple
superset instances. Audit metadata isn't copied over.
"""
lookup_database: Callable[[BaseDatasource], Optional[Database]]
lookup_datasource: Callable[[BaseDatasource], Optional[BaseDatasource]]
if isinstance(i_datasource, SqlaTable):
lookup_database = lookup_sqla_database
lookup_datasource = lookup_sqla_table
else:
lookup_database = lookup_druid_cluster
lookup_datasource = lookup_druid_datasource
return import_datasource(
db.session,
i_datasource,
lookup_database,
lookup_datasource,
import_time,
database_id,
)
def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric:
return (
session.query(SqlMetric)
.filter(
SqlMetric.table_id == metric.table_id,
SqlMetric.metric_name == metric.metric_name,
)
.first()
)
def lookup_druid_metric(session: Session, metric: DruidMetric) -> DruidMetric:
return (
session.query(DruidMetric)
.filter(
DruidMetric.datasource_id == metric.datasource_id,
DruidMetric.metric_name == metric.metric_name,
)
.first()
)
def import_metric(session: Session, metric: BaseMetric) -> BaseMetric:
if isinstance(metric, SqlMetric):
lookup_metric = lookup_sqla_metric
else:
lookup_metric = lookup_druid_metric
return import_simple_obj(session, metric, lookup_metric)
def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn:
return (
session.query(TableColumn)
.filter(
TableColumn.table_id == column.table_id,
TableColumn.column_name == column.column_name,
)
.first()
)
def lookup_druid_column(session: Session, column: DruidColumn) -> DruidColumn:
return (
session.query(DruidColumn)
.filter(
DruidColumn.datasource_id == column.datasource_id,
DruidColumn.column_name == column.column_name,
)
.first()
)
def import_column(session: Session, column: BaseColumn) -> BaseColumn:
if isinstance(column, TableColumn):
lookup_column = lookup_sqla_column
else:
lookup_column = lookup_druid_column
return import_simple_obj(session, column, lookup_column)
def import_datasource( # pylint: disable=too-many-arguments
session: Session,
i_datasource: Model,
lookup_database: Callable[[Model], Optional[Model]],
lookup_datasource: Callable[[Model], Optional[Model]],
import_time: Optional[int] = None,
database_id: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export datasources between multiple
superset instances. Audit metadata isn't copies over.
"""
make_transient(i_datasource)
logger.info("Started import of the datasource: %s", i_datasource.to_json())
i_datasource.id = None
i_datasource.database_id = (
database_id
if database_id
else getattr(lookup_database(i_datasource), "id", None)
)
i_datasource.alter_params(import_time=import_time)
# override the datasource
datasource = lookup_datasource(i_datasource)
if datasource:
datasource.override(i_datasource)
session.flush()
else:
datasource = i_datasource.copy()
session.add(datasource)
session.flush()
for metric in i_datasource.metrics:
new_m = metric.copy()
new_m.table_id = datasource.id
logger.info(
"Importing metric %s from the datasource: %s",
new_m.to_json(),
i_datasource.full_name,
)
imported_m = import_metric(session, new_m)
if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]:
datasource.metrics.append(imported_m)
for column in i_datasource.columns:
new_c = column.copy()
new_c.table_id = datasource.id
logger.info(
"Importing column %s from the datasource: %s",
new_c.to_json(),
i_datasource.full_name,
)
imported_c = import_column(session, new_c)
if imported_c.column_name not in [c.column_name for c in datasource.columns]:
datasource.columns.append(imported_c)
session.flush()
return datasource.id
def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Session, Model], Model]
) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
# find if the column was already imported
existing_column = lookup_obj(session, i_obj)
i_obj.table = None
if existing_column:
existing_column.override(i_obj)
session.flush()
return existing_column
session.add(i_obj)
session.flush()
return i_obj
def import_from_dict(
session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
) -> None:
"""Imports databases and druid clusters from dictionary"""
if not sync:
sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
Database.import_from_dict(session, database, sync=sync)
logger.info(
"Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
)
for datasource in data.get(DRUID_CLUSTERS_KEY, []):
DruidCluster.import_from_dict(session, datasource, sync=sync)
session.commit()
else:
logger.info("Supplied object is not a dictionary.")
class ImportDatasetsCommand(BaseCommand):
"""
Import datasources in YAML format.
This is the original unversioned format used to export and import datasources
in Superset.
"""
def __init__(self, contents: Dict[str, str], sync: Optional[List[str]] = None):
self.contents = contents
self.sync = sync
def run(self) -> None:
self.validate()
for file_name, content in self.contents.items():
logger.info("Importing dataset from file %s", file_name)
import_from_dict(db.session, yaml.safe_load(content), sync=self.sync)
def validate(self) -> None:
# ensure all files are YAML
for content in self.contents.values():
try:
yaml.safe_load(content)
except yaml.parser.ParserError:
logger.exception("Invalid YAML file")
raise
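
End to end, this module now replaces both the per-model import_obj methods and the old dict_import_export.import_from_dict. A usage sketch, assuming an app context; the file name is illustrative and the YAML payload is an existing datasource export keyed by DATABASES_KEY / DRUID_CLUSTERS_KEY, as import_from_dict above expects:

import yaml

from superset import db
from superset.datasets.commands.importers.v0 import ImportDatasetsCommand, import_from_dict

# High-level: hand the command a mapping of file names to raw YAML contents.
contents = {"databases.yaml": open("databases.yaml").read()}  # hypothetical export file
ImportDatasetsCommand(contents, sync=["metrics", "columns"]).run()

# Lower-level: import a single already-parsed document.
with open("databases.yaml") as fd:
    import_from_dict(db.session, yaml.safe_load(fd), sync=["metrics", "columns"])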

View File

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict

 from sqlalchemy.orm import Session
@@ -75,24 +75,3 @@ def export_to_dict(
     if clusters:
         data[DRUID_CLUSTERS_KEY] = clusters
     return data
-
-
-def import_from_dict(
-    session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
-) -> None:
-    """Imports databases and druid clusters from dictionary"""
-    if not sync:
-        sync = []
-    if isinstance(data, dict):
-        logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
-        for database in data.get(DATABASES_KEY, []):
-            Database.import_from_dict(session, database, sync=sync)
-        logger.info(
-            "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
-        )
-        for datasource in data.get(DRUID_CLUSTERS_KEY, []):
-            DruidCluster.import_from_dict(session, datasource, sync=sync)
-        session.commit()
-    else:
-        logger.info("Supplied object is not a dictionary.")

View File

@@ -1,105 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Callable, Optional
from flask_appbuilder import Model
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import make_transient
logger = logging.getLogger(__name__)
def import_datasource( # pylint: disable=too-many-arguments
session: Session,
i_datasource: Model,
lookup_database: Callable[[Model], Model],
lookup_datasource: Callable[[Model], Model],
import_time: Optional[int] = None,
database_id: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export datasources between multiple
superset instances. Audit metadata isn't copies over.
"""
make_transient(i_datasource)
logger.info("Started import of the datasource: %s", i_datasource.to_json())
i_datasource.id = None
i_datasource.database_id = (
database_id if database_id else lookup_database(i_datasource).id
)
i_datasource.alter_params(import_time=import_time)
# override the datasource
datasource = lookup_datasource(i_datasource)
if datasource:
datasource.override(i_datasource)
session.flush()
else:
datasource = i_datasource.copy()
session.add(datasource)
session.flush()
for metric in i_datasource.metrics:
new_m = metric.copy()
new_m.table_id = datasource.id
logger.info(
"Importing metric %s from the datasource: %s",
new_m.to_json(),
i_datasource.full_name,
)
imported_m = i_datasource.metric_class.import_obj(new_m)
if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]:
datasource.metrics.append(imported_m)
for column in i_datasource.columns:
new_c = column.copy()
new_c.table_id = datasource.id
logger.info(
"Importing column %s from the datasource: %s",
new_c.to_json(),
i_datasource.full_name,
)
imported_c = i_datasource.column_class.import_obj(new_c)
if imported_c.column_name not in [c.column_name for c in datasource.columns]:
datasource.columns.append(imported_c)
session.flush()
return datasource.id
def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
# find if the column was already imported
existing_column = lookup_obj(i_obj)
i_obj.table = None
if existing_column:
existing_column.override(i_obj)
session.flush()
return existing_column
session.add(i_obj)
session.flush()
return i_obj
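
The deleted helpers survive under new names and a slightly different contract: import_datasource and import_simple_obj now live in superset.datasets.commands.importers.v0, and the lookup callbacks passed to import_simple_obj receive the session as a first argument instead of closing over it. A sketch of a compatible lookup under that assumption:

from typing import Optional

from sqlalchemy.orm import Session

from superset.connectors.sqla.models import SqlMetric
from superset.datasets.commands.importers.v0 import import_simple_obj

def lookup_metric(session: Session, metric: SqlMetric) -> Optional[SqlMetric]:
    # mirrors lookup_sqla_metric from the new module
    return (
        session.query(SqlMetric)
        .filter(
            SqlMetric.table_id == metric.table_id,
            SqlMetric.metric_name == metric.metric_name,
        )
        .first()
    )

# imported = import_simple_obj(db.session, new_metric, lookup_metric)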

View File

@@ -33,6 +33,7 @@ from superset.connectors.druid.models import (
 )
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.dashboards.commands.importers.v0 import import_chart, import_dashboard
+from superset.datasets.commands.importers.v0 import import_dataset
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.utils.core import get_example_database
@@ -567,7 +568,7 @@ class TestImportExport(SupersetTestCase):
     def test_import_table_no_metadata(self):
         db_id = get_example_database().id
         table = self.create_table("pure_table", id=10001)
-        imported_id = SqlaTable.import_obj(table, db_id, import_time=1989)
+        imported_id = import_dataset(table, db_id, import_time=1989)
         imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
@@ -576,7 +577,7 @@ class TestImportExport(SupersetTestCase):
             "table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
         )
         db_id = get_example_database().id
-        imported_id = SqlaTable.import_obj(table, db_id, import_time=1990)
+        imported_id = import_dataset(table, db_id, import_time=1990)
         imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
@@ -592,7 +593,7 @@ class TestImportExport(SupersetTestCase):
             metric_names=["m1", "m2"],
         )
         db_id = get_example_database().id
-        imported_id = SqlaTable.import_obj(table, db_id, import_time=1991)
+        imported_id = import_dataset(table, db_id, import_time=1991)
         imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
@@ -602,7 +603,7 @@ class TestImportExport(SupersetTestCase):
             "table_override", id=10003, cols_names=["col1"], metric_names=["m1"]
         )
         db_id = get_example_database().id
-        imported_id = SqlaTable.import_obj(table, db_id, import_time=1991)
+        imported_id = import_dataset(table, db_id, import_time=1991)

         table_over = self.create_table(
             "table_override",
@@ -610,7 +611,7 @@ class TestImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_over_id = SqlaTable.import_obj(table_over, db_id, import_time=1992)
+        imported_over_id = import_dataset(table_over, db_id, import_time=1992)

         imported_over = self.get_table_by_id(imported_over_id)
         self.assertEqual(imported_id, imported_over.id)
@@ -630,7 +631,7 @@ class TestImportExport(SupersetTestCase):
             metric_names=["new_metric1"],
         )
         db_id = get_example_database().id
-        imported_id = SqlaTable.import_obj(table, db_id, import_time=1993)
+        imported_id = import_dataset(table, db_id, import_time=1993)

         copy_table = self.create_table(
             "copy_cat",
@@ -638,14 +639,14 @@ class TestImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_id_copy = SqlaTable.import_obj(copy_table, db_id, import_time=1994)
+        imported_id_copy = import_dataset(copy_table, db_id, import_time=1994)

         self.assertEqual(imported_id, imported_id_copy)
         self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))

     def test_import_druid_no_metadata(self):
         datasource = self.create_druid_datasource("pure_druid", id=10001)
-        imported_id = DruidDatasource.import_obj(datasource, import_time=1989)
+        imported_id = import_dataset(datasource, import_time=1989)
         imported = self.get_datasource(imported_id)
         self.assert_datasource_equals(datasource, imported)
@@ -653,7 +654,7 @@ class TestImportExport(SupersetTestCase):
         datasource = self.create_druid_datasource(
             "druid_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
         )
-        imported_id = DruidDatasource.import_obj(datasource, import_time=1990)
+        imported_id = import_dataset(datasource, import_time=1990)
         imported = self.get_datasource(imported_id)
         self.assert_datasource_equals(datasource, imported)
         self.assertEqual(
@@ -668,7 +669,7 @@ class TestImportExport(SupersetTestCase):
             cols_names=["c1", "c2"],
             metric_names=["m1", "m2"],
         )
-        imported_id = DruidDatasource.import_obj(datasource, import_time=1991)
+        imported_id = import_dataset(datasource, import_time=1991)
         imported = self.get_datasource(imported_id)
         self.assert_datasource_equals(datasource, imported)
@@ -676,14 +677,14 @@ class TestImportExport(SupersetTestCase):
         datasource = self.create_druid_datasource(
             "druid_override", id=10004, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_id = DruidDatasource.import_obj(datasource, import_time=1991)
+        imported_id = import_dataset(datasource, import_time=1991)

         table_over = self.create_druid_datasource(
             "druid_override",
             id=10004,
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_over_id = DruidDatasource.import_obj(table_over, import_time=1992)
+        imported_over_id = import_dataset(table_over, import_time=1992)

         imported_over = self.get_datasource(imported_over_id)
         self.assertEqual(imported_id, imported_over.id)
@@ -702,7 +703,7 @@ class TestImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_id = DruidDatasource.import_obj(datasource, import_time=1993)
+        imported_id = import_dataset(datasource, import_time=1993)

         copy_datasource = self.create_druid_datasource(
             "copy_cat",
@@ -710,7 +711,7 @@ class TestImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_id_copy = DruidDatasource.import_obj(copy_datasource, import_time=1994)
+        imported_id_copy = import_dataset(copy_datasource, import_time=1994)

         self.assertEqual(imported_id, imported_id_copy)
         self.assert_datasource_equals(copy_datasource, self.get_datasource(imported_id))