diff --git a/superset/cli.py b/superset/cli.py index 6d1d2fb92c..5130dbfffb 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -301,11 +301,11 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None: ) def import_datasources(path: str, sync: str, recursive: bool) -> None: """Import datasources from YAML""" - from superset.utils import dict_import_export + from superset.datasets.commands.importers.v0 import ImportDatasetsCommand sync_array = sync.split(",") path_object = Path(path) - files = [] + files: List[Path] = [] if path_object.is_file(): files.append(path_object) elif path_object.exists() and not recursive: @@ -314,16 +314,11 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None: elif path_object.exists() and recursive: files.extend(path_object.rglob("*.yaml")) files.extend(path_object.rglob("*.yml")) - for file_ in files: - logger.info("Importing datasources from file %s", file_) - try: - with file_.open() as data_stream: - dict_import_export.import_from_dict( - db.session, yaml.safe_load(data_stream), sync=sync_array - ) - except Exception as ex: # pylint: disable=broad-except - logger.error("Error when importing datasources from file %s", file_) - logger.error(ex) + contents = {path.name: open(path).read() for path in files} + try: + ImportDatasetsCommand(contents, sync_array).run() + except Exception: # pylint: disable=broad-except + logger.exception("Error when importing dataset") @superset.command() diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index a8d00fd4a1..644d6a345e 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -56,7 +56,7 @@ from superset.exceptions import SupersetException from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict -from superset.utils import core as utils, import_datasource +from superset.utils import core as utils try: import requests @@ -378,20 +378,6 @@ class DruidColumn(Model, BaseColumn): metric.datasource_id = self.datasource_id db.session.add(metric) - @classmethod - def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn": - def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]: - return ( - db.session.query(DruidColumn) - .filter( - DruidColumn.datasource_id == lookup_column.datasource_id, - DruidColumn.column_name == lookup_column.column_name, - ) - .first() - ) - - return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) - class DruidMetric(Model, BaseMetric): @@ -447,20 +433,6 @@ class DruidMetric(Model, BaseMetric): def get_perm(self) -> Optional[str]: return self.perm - @classmethod - def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric": - def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]: - return ( - db.session.query(DruidMetric) - .filter( - DruidMetric.datasource_id == lookup_metric.datasource_id, - DruidMetric.metric_name == lookup_metric.metric_name, - ) - .first() - ) - - return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj) - druiddatasource_user = Table( "druiddatasource_user", @@ -610,34 +582,6 @@ class DruidDatasource(Model, BaseDatasource): def get_metric_obj(self, metric_name: str) -> Dict[str, Any]: return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0] - @classmethod - def import_obj( - cls, i_datasource: "DruidDatasource", import_time: 
Optional[int] = None - ) -> int: - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overridden if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. - """ - - def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]: - return ( - db.session.query(DruidDatasource) - .filter( - DruidDatasource.datasource_name == d.datasource_name, - DruidDatasource.cluster_id == d.cluster_id, - ) - .first() - ) - - def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]: - return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first() - - return import_datasource.import_datasource( - db.session, i_datasource, lookup_cluster, lookup_datasource, import_time - ) - def latest_metadata(self) -> Optional[Dict[str, Any]]: """Returns segment metadata from the latest segment""" logger.info("Syncing datasource [{}]".format(self.datasource_name)) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a7c078d006..34961f0b4b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -47,7 +47,6 @@ from sqlalchemy import ( ) from sqlalchemy.exc import CompileError from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session -from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql.expression import Label, Select, TextAsFrom @@ -58,11 +57,7 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetr from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import ( - DatabaseNotFound, - QueryObjectValidationError, - SupersetSecurityException, -) +from superset.exceptions import QueryObjectValidationError, SupersetSecurityException from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -74,7 +69,7 @@ from superset.models.helpers import AuditMixinNullable, QueryResult from superset.result_set import SupersetResultSet from superset.sql_parse import ParsedQuery from superset.typing import Metric, QueryObjectDict -from superset.utils import core as utils, import_datasource +from superset.utils import core as utils config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -290,20 +285,6 @@ class TableColumn(Model, BaseColumn): ) return self.table.make_sqla_column_compatible(time_expr, label) - @classmethod - def import_obj(cls, i_column: "TableColumn") -> "TableColumn": - def lookup_obj(lookup_column: TableColumn) -> TableColumn: - return ( - db.session.query(TableColumn) - .filter( - TableColumn.table_id == lookup_column.table_id, - TableColumn.column_name == lookup_column.column_name, - ) - .first() - ) - - return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) - def dttm_sql_literal( self, dttm: DateTime, @@ -412,20 +393,6 @@ class SqlMetric(Model, BaseMetric): def get_perm(self) -> Optional[str]: return self.perm - @classmethod - def import_obj(cls, i_metric: "SqlMetric") -> "SqlMetric": - def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric: - return ( - db.session.query(SqlMetric) - .filter( - SqlMetric.table_id == lookup_metric.table_id, - SqlMetric.metric_name == lookup_metric.metric_name, - ) - 
.first() - ) - - return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj) - def get_extra_dict(self) -> Dict[str, Any]: try: return json.loads(self.extra) @@ -1416,56 +1383,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at db.session.commit() return results - @classmethod - def import_obj( - cls, - i_datasource: "SqlaTable", - database_id: Optional[int] = None, - import_time: Optional[int] = None, - ) -> int: - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overrided if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. - """ - - def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable": - return ( - db.session.query(SqlaTable) - .join(Database) - .filter( - SqlaTable.table_name == table_.table_name, - SqlaTable.schema == table_.schema, - Database.id == table_.database_id, - ) - .first() - ) - - def lookup_database(table_: SqlaTable) -> Database: - try: - return ( - db.session.query(Database) - .filter_by(database_name=table_.params_dict["database_name"]) - .one() - ) - except NoResultFound: - raise DatabaseNotFound( - _( - "Database '%(name)s' is not found", - name=table_.params_dict["database_name"], - ) - ) - - return import_datasource.import_datasource( - db.session, - i_datasource, - lookup_database, - lookup_sqlatable, - import_time, - database_id, - ) - @classmethod def query_datasources_by_name( cls, diff --git a/superset/dashboards/commands/importers/__init__.py b/superset/dashboards/commands/importers/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/superset/dashboards/commands/importers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index 8040c248b8..851ecab941 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -19,7 +19,7 @@ import logging import time from copy import copy from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from flask_babel import lazy_gettext as _ from sqlalchemy.orm import make_transient, Session @@ -27,6 +27,7 @@ from sqlalchemy.orm import make_transient, Session from superset import ConnectorRegistry, db from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.datasets.commands.importers.v0 import import_dataset from superset.exceptions import DashboardImportException from superset.models.dashboard import Dashboard from superset.models.slice import Slice @@ -301,7 +302,7 @@ def import_dashboards( if not data: raise DashboardImportException(_("No data in file")) for table in data["datasources"]: - type(table).import_obj(table, database_id, import_time=import_time) + import_dataset(table, database_id, import_time=import_time) session.commit() for dashboard in data["dashboards"]: import_dashboard(dashboard, import_time=import_time) @@ -333,4 +334,5 @@ class ImportDashboardsCommand(BaseCommand): try: json.loads(content) except ValueError: + logger.exception("Invalid JSON file") raise diff --git a/superset/datasets/commands/importers/__init__.py b/superset/datasets/commands/importers/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/superset/datasets/commands/importers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py new file mode 100644 index 0000000000..5b3ed25d73 --- /dev/null +++ b/superset/datasets/commands/importers/v0.py @@ -0,0 +1,303 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Callable, Dict, List, Optional + +import yaml +from flask_appbuilder import Model +from sqlalchemy.orm import Session +from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.session import make_transient + +from superset import db +from superset.commands.base import BaseCommand +from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric +from superset.connectors.druid.models import ( + DruidCluster, + DruidColumn, + DruidDatasource, + DruidMetric, +) +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.models.core import Database +from superset.utils.dict_import_export import DATABASES_KEY, DRUID_CLUSTERS_KEY + +logger = logging.getLogger(__name__) + + +def lookup_sqla_table(table: SqlaTable) -> Optional[SqlaTable]: + return ( + db.session.query(SqlaTable) + .join(Database) + .filter( + SqlaTable.table_name == table.table_name, + SqlaTable.schema == table.schema, + Database.id == table.database_id, + ) + .first() + ) + + +def lookup_sqla_database(table: SqlaTable) -> Optional[Database]: + try: + return ( + db.session.query(Database) + .filter_by(database_name=table.params_dict["database_name"]) + .one() + ) + except NoResultFound: + raise DatabaseNotFoundError + + +def lookup_druid_cluster(datasource: DruidDatasource) -> Optional[DruidCluster]: + return db.session.query(DruidCluster).filter_by(id=datasource.cluster_id).first() + + +def lookup_druid_datasource(datasource: DruidDatasource) -> Optional[DruidDatasource]: + return ( + db.session.query(DruidDatasource) + .filter( + DruidDatasource.datasource_name == datasource.datasource_name, + DruidDatasource.cluster_id == datasource.cluster_id, + ) + .first() + ) + + +def import_dataset( + i_datasource: BaseDatasource, + database_id: Optional[int] = None, + import_time: Optional[int] = None, +) -> int: + """Imports the datasource from the object to the database. + + Metrics, columns and the datasource will be overridden if they exist. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copied over. 
+ """ + + lookup_database: Callable[[BaseDatasource], Optional[Database]] + lookup_datasource: Callable[[BaseDatasource], Optional[BaseDatasource]] + if isinstance(i_datasource, SqlaTable): + lookup_database = lookup_sqla_database + lookup_datasource = lookup_sqla_table + else: + lookup_database = lookup_druid_cluster + lookup_datasource = lookup_druid_datasource + + return import_datasource( + db.session, + i_datasource, + lookup_database, + lookup_datasource, + import_time, + database_id, + ) + + +def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric: + return ( + session.query(SqlMetric) + .filter( + SqlMetric.table_id == metric.table_id, + SqlMetric.metric_name == metric.metric_name, + ) + .first() + ) + + +def lookup_druid_metric(session: Session, metric: DruidMetric) -> DruidMetric: + return ( + session.query(DruidMetric) + .filter( + DruidMetric.datasource_id == metric.datasource_id, + DruidMetric.metric_name == metric.metric_name, + ) + .first() + ) + + +def import_metric(session: Session, metric: BaseMetric) -> BaseMetric: + if isinstance(metric, SqlMetric): + lookup_metric = lookup_sqla_metric + else: + lookup_metric = lookup_druid_metric + return import_simple_obj(session, metric, lookup_metric) + + +def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn: + return ( + session.query(TableColumn) + .filter( + TableColumn.table_id == column.table_id, + TableColumn.column_name == column.column_name, + ) + .first() + ) + + +def lookup_druid_column(session: Session, column: DruidColumn) -> DruidColumn: + return ( + session.query(DruidColumn) + .filter( + DruidColumn.datasource_id == column.datasource_id, + DruidColumn.column_name == column.column_name, + ) + .first() + ) + + +def import_column(session: Session, column: BaseColumn) -> BaseColumn: + if isinstance(column, TableColumn): + lookup_column = lookup_sqla_column + else: + lookup_column = lookup_druid_column + return import_simple_obj(session, column, lookup_column) + + +def import_datasource( # pylint: disable=too-many-arguments + session: Session, + i_datasource: Model, + lookup_database: Callable[[Model], Optional[Model]], + lookup_datasource: Callable[[Model], Optional[Model]], + import_time: Optional[int] = None, + database_id: Optional[int] = None, +) -> int: + """Imports the datasource from the object to the database. + + Metrics, columns and the datasource will be overridden if they exist. + This function can be used to import/export datasources between multiple + superset instances. Audit metadata isn't copied over. 
+ """ + make_transient(i_datasource) + logger.info("Started import of the datasource: %s", i_datasource.to_json()) + + i_datasource.id = None + i_datasource.database_id = ( + database_id + if database_id + else getattr(lookup_database(i_datasource), "id", None) + ) + i_datasource.alter_params(import_time=import_time) + + # override the datasource + datasource = lookup_datasource(i_datasource) + + if datasource: + datasource.override(i_datasource) + session.flush() + else: + datasource = i_datasource.copy() + session.add(datasource) + session.flush() + + for metric in i_datasource.metrics: + new_m = metric.copy() + new_m.table_id = datasource.id + logger.info( + "Importing metric %s from the datasource: %s", + new_m.to_json(), + i_datasource.full_name, + ) + imported_m = import_metric(session, new_m) + if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]: + datasource.metrics.append(imported_m) + + for column in i_datasource.columns: + new_c = column.copy() + new_c.table_id = datasource.id + logger.info( + "Importing column %s from the datasource: %s", + new_c.to_json(), + i_datasource.full_name, + ) + imported_c = import_column(session, new_c) + if imported_c.column_name not in [c.column_name for c in datasource.columns]: + datasource.columns.append(imported_c) + session.flush() + return datasource.id + + +def import_simple_obj( + session: Session, i_obj: Model, lookup_obj: Callable[[Session, Model], Model] +) -> Model: + make_transient(i_obj) + i_obj.id = None + i_obj.table = None + + # find if the column was already imported + existing_column = lookup_obj(session, i_obj) + i_obj.table = None + if existing_column: + existing_column.override(i_obj) + session.flush() + return existing_column + + session.add(i_obj) + session.flush() + return i_obj + + +def import_from_dict( + session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None +) -> None: + """Imports databases and druid clusters from dictionary""" + if not sync: + sync = [] + if isinstance(data, dict): + logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) + for database in data.get(DATABASES_KEY, []): + Database.import_from_dict(session, database, sync=sync) + + logger.info( + "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY + ) + for datasource in data.get(DRUID_CLUSTERS_KEY, []): + DruidCluster.import_from_dict(session, datasource, sync=sync) + session.commit() + else: + logger.info("Supplied object is not a dictionary.") + + +class ImportDatasetsCommand(BaseCommand): + """ + Import datasources in YAML format. + + This is the original unversioned format used to export and import datasources + in Superset. 
+ """ + + def __init__(self, contents: Dict[str, str], sync: Optional[List[str]] = None): + self.contents = contents + self.sync = sync + + def run(self) -> None: + self.validate() + + for file_name, content in self.contents.items(): + logger.info("Importing dataset from file %s", file_name) + import_from_dict(db.session, yaml.safe_load(content), sync=self.sync) + + def validate(self) -> None: + # ensure all files are YAML + for content in self.contents.values(): + try: + yaml.safe_load(content) + except yaml.parser.ParserError: + logger.exception("Invalid YAML file") + raise diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index 3a37e91c87..256f1b5655 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict from sqlalchemy.orm import Session @@ -75,24 +75,3 @@ def export_to_dict( if clusters: data[DRUID_CLUSTERS_KEY] = clusters return data - - -def import_from_dict( - session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None -) -> None: - """Imports databases and druid clusters from dictionary""" - if not sync: - sync = [] - if isinstance(data, dict): - logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) - for database in data.get(DATABASES_KEY, []): - Database.import_from_dict(session, database, sync=sync) - - logger.info( - "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY - ) - for datasource in data.get(DRUID_CLUSTERS_KEY, []): - DruidCluster.import_from_dict(session, datasource, sync=sync) - session.commit() - else: - logger.info("Supplied object is not a dictionary.") diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py deleted file mode 100644 index 25da876b28..0000000000 --- a/superset/utils/import_datasource.py +++ /dev/null @@ -1,105 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from typing import Callable, Optional - -from flask_appbuilder import Model -from sqlalchemy.orm import Session -from sqlalchemy.orm.session import make_transient - -logger = logging.getLogger(__name__) - - -def import_datasource( # pylint: disable=too-many-arguments - session: Session, - i_datasource: Model, - lookup_database: Callable[[Model], Model], - lookup_datasource: Callable[[Model], Model], - import_time: Optional[int] = None, - database_id: Optional[int] = None, -) -> int: - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overrided if exists. 
- This function can be used to import/export datasources between multiple - superset instances. Audit metadata isn't copies over. - """ - make_transient(i_datasource) - logger.info("Started import of the datasource: %s", i_datasource.to_json()) - - i_datasource.id = None - i_datasource.database_id = ( - database_id if database_id else lookup_database(i_datasource).id - ) - i_datasource.alter_params(import_time=import_time) - - # override the datasource - datasource = lookup_datasource(i_datasource) - - if datasource: - datasource.override(i_datasource) - session.flush() - else: - datasource = i_datasource.copy() - session.add(datasource) - session.flush() - - for metric in i_datasource.metrics: - new_m = metric.copy() - new_m.table_id = datasource.id - logger.info( - "Importing metric %s from the datasource: %s", - new_m.to_json(), - i_datasource.full_name, - ) - imported_m = i_datasource.metric_class.import_obj(new_m) - if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]: - datasource.metrics.append(imported_m) - - for column in i_datasource.columns: - new_c = column.copy() - new_c.table_id = datasource.id - logger.info( - "Importing column %s from the datasource: %s", - new_c.to_json(), - i_datasource.full_name, - ) - imported_c = i_datasource.column_class.import_obj(new_c) - if imported_c.column_name not in [c.column_name for c in datasource.columns]: - datasource.columns.append(imported_c) - session.flush() - return datasource.id - - -def import_simple_obj( - session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model] -) -> Model: - make_transient(i_obj) - i_obj.id = None - i_obj.table = None - - # find if the column was already imported - existing_column = lookup_obj(i_obj) - i_obj.table = None - if existing_column: - existing_column.override(i_obj) - session.flush() - return existing_column - - session.add(i_obj) - session.flush() - return i_obj diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index c19b957cb1..aac3a5162b 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -33,6 +33,7 @@ from superset.connectors.druid.models import ( ) from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.dashboards.commands.importers.v0 import import_chart, import_dashboard +from superset.datasets.commands.importers.v0 import import_dataset from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.utils.core import get_example_database @@ -567,7 +568,7 @@ class TestImportExport(SupersetTestCase): def test_import_table_no_metadata(self): db_id = get_example_database().id table = self.create_table("pure_table", id=10001) - imported_id = SqlaTable.import_obj(table, db_id, import_time=1989) + imported_id = import_dataset(table, db_id, import_time=1989) imported = self.get_table_by_id(imported_id) self.assert_table_equals(table, imported) @@ -576,7 +577,7 @@ class TestImportExport(SupersetTestCase): "table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"] ) db_id = get_example_database().id - imported_id = SqlaTable.import_obj(table, db_id, import_time=1990) + imported_id = import_dataset(table, db_id, import_time=1990) imported = self.get_table_by_id(imported_id) self.assert_table_equals(table, imported) self.assertEqual( @@ -592,7 +593,7 @@ class TestImportExport(SupersetTestCase): metric_names=["m1", "m2"], ) db_id = get_example_database().id - imported_id = SqlaTable.import_obj(table, db_id, import_time=1991) + 
imported_id = import_dataset(table, db_id, import_time=1991) imported = self.get_table_by_id(imported_id) self.assert_table_equals(table, imported) @@ -602,7 +603,7 @@ class TestImportExport(SupersetTestCase): "table_override", id=10003, cols_names=["col1"], metric_names=["m1"] ) db_id = get_example_database().id - imported_id = SqlaTable.import_obj(table, db_id, import_time=1991) + imported_id = import_dataset(table, db_id, import_time=1991) table_over = self.create_table( "table_override", @@ -610,7 +611,7 @@ class TestImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_over_id = SqlaTable.import_obj(table_over, db_id, import_time=1992) + imported_over_id = import_dataset(table_over, db_id, import_time=1992) imported_over = self.get_table_by_id(imported_over_id) self.assertEqual(imported_id, imported_over.id) @@ -630,7 +631,7 @@ class TestImportExport(SupersetTestCase): metric_names=["new_metric1"], ) db_id = get_example_database().id - imported_id = SqlaTable.import_obj(table, db_id, import_time=1993) + imported_id = import_dataset(table, db_id, import_time=1993) copy_table = self.create_table( "copy_cat", @@ -638,14 +639,14 @@ class TestImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_id_copy = SqlaTable.import_obj(copy_table, db_id, import_time=1994) + imported_id_copy = import_dataset(copy_table, db_id, import_time=1994) self.assertEqual(imported_id, imported_id_copy) self.assert_table_equals(copy_table, self.get_table_by_id(imported_id)) def test_import_druid_no_metadata(self): datasource = self.create_druid_datasource("pure_druid", id=10001) - imported_id = DruidDatasource.import_obj(datasource, import_time=1989) + imported_id = import_dataset(datasource, import_time=1989) imported = self.get_datasource(imported_id) self.assert_datasource_equals(datasource, imported) @@ -653,7 +654,7 @@ class TestImportExport(SupersetTestCase): datasource = self.create_druid_datasource( "druid_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"] ) - imported_id = DruidDatasource.import_obj(datasource, import_time=1990) + imported_id = import_dataset(datasource, import_time=1990) imported = self.get_datasource(imported_id) self.assert_datasource_equals(datasource, imported) self.assertEqual( @@ -668,7 +669,7 @@ class TestImportExport(SupersetTestCase): cols_names=["c1", "c2"], metric_names=["m1", "m2"], ) - imported_id = DruidDatasource.import_obj(datasource, import_time=1991) + imported_id = import_dataset(datasource, import_time=1991) imported = self.get_datasource(imported_id) self.assert_datasource_equals(datasource, imported) @@ -676,14 +677,14 @@ class TestImportExport(SupersetTestCase): datasource = self.create_druid_datasource( "druid_override", id=10004, cols_names=["col1"], metric_names=["m1"] ) - imported_id = DruidDatasource.import_obj(datasource, import_time=1991) + imported_id = import_dataset(datasource, import_time=1991) table_over = self.create_druid_datasource( "druid_override", id=10004, cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_over_id = DruidDatasource.import_obj(table_over, import_time=1992) + imported_over_id = import_dataset(table_over, import_time=1992) imported_over = self.get_datasource(imported_over_id) self.assertEqual(imported_id, imported_over.id) @@ -702,7 +703,7 @@ class TestImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - 
imported_id = DruidDatasource.import_obj(datasource, import_time=1993) + imported_id = import_dataset(datasource, import_time=1993) copy_datasource = self.create_druid_datasource( "copy_cat", @@ -710,7 +711,7 @@ class TestImportExport(SupersetTestCase): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_id_copy = DruidDatasource.import_obj(copy_datasource, import_time=1994) + imported_id_copy = import_dataset(copy_datasource, import_time=1994) self.assertEqual(imported_id, imported_id_copy) self.assert_datasource_equals(copy_datasource, self.get_datasource(imported_id))
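
Reviewer note: below is a minimal usage sketch of the two entry points this patch introduces, ImportDatasetsCommand (the path now taken by the `superset import_datasources` CLI) and import_dataset (the helper the dashboard importer and the updated tests call). The file name, YAML payload, and sync values are illustrative assumptions based on the unversioned v0 export layout, not content taken from this patch, and the snippet assumes an initialized Superset app context (e.g. `superset shell`) so that db.session is bound.

# Hypothetical usage sketch; the payload below follows the assumed v0 export
# layout (databases/tables keys) and is not taken verbatim from this patch.
from superset.datasets.commands.importers.v0 import ImportDatasetsCommand

contents = {
    "datasources.yaml": """
databases:
- database_name: examples
  tables:
  - table_name: my_table
    columns:
    - column_name: col1
    metrics:
    - metric_name: count
      expression: COUNT(*)
""",
}

# Bulk import keyed by file name, mirroring the new import_datasources CLI path;
# validate() checks that every file parses as YAML before run() imports it.
ImportDatasetsCommand(contents, ["metrics", "columns"]).run()

# The dashboard importer and the updated tests call import_dataset() directly
# on an in-memory datasource object instead of the removed import_obj methods:
#   from superset.datasets.commands.importers.v0 import import_dataset
#   imported_id = import_dataset(table, database_id, import_time=1990)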