From 4a070cfceb85d53a166db6bc787321245de778ec Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 11 Nov 2020 14:50:18 -0800 Subject: [PATCH] chore: consolidate dashboard import logic (#11529) * Consolidate dash import logic * Fix lint * Remove Slice.import_obj * Remove unused import * Fix log --- superset/cli.py | 19 +- superset/dashboards/commands/importers/v0.py | 336 +++++++++++++++++++ superset/models/dashboard.py | 182 +--------- superset/models/slice.py | 45 +-- superset/utils/dashboard_import_export.py | 67 ---- superset/views/core.py | 9 +- tests/import_export_tests.py | 33 +- 7 files changed, 367 insertions(+), 324 deletions(-) create mode 100644 superset/dashboards/commands/importers/v0.py diff --git a/superset/cli.py b/superset/cli.py index f0f7f1e1e5..6d1d2fb92c 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -19,7 +19,7 @@ import logging from datetime import datetime, timedelta from subprocess import Popen from sys import stdout -from typing import Any, Dict, Type, Union +from typing import Any, Dict, List, Type, Union import click import yaml @@ -235,10 +235,10 @@ def refresh_druid(datasource: str, merge: bool) -> None: ) def import_dashboards(path: str, recursive: bool, username: str) -> None: """Import dashboards from JSON""" - from superset.utils import dashboard_import_export + from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand path_object = Path(path) - files = [] + files: List[Path] = [] if path_object.is_file(): files.append(path_object) elif path_object.exists() and not recursive: @@ -247,14 +247,11 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None: files.extend(path_object.rglob("*.json")) if username is not None: g.user = security_manager.find_user(username=username) - for file_ in files: - logger.info("Importing dashboard from file %s", file_) - try: - with file_.open() as data_stream: - dashboard_import_export.import_dashboards(db.session, data_stream) - except Exception as ex: # pylint: disable=broad-except - logger.error("Error when importing dashboard from file %s", file_) - logger.error(ex) + contents = {path.name: open(path).read() for path in files} + try: + ImportDashboardsCommand(contents).run() + except Exception: # pylint: disable=broad-except + logger.exception("Error when importing dashboard") @superset.command() diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py new file mode 100644 index 0000000000..8040c248b8 --- /dev/null +++ b/superset/dashboards/commands/importers/v0.py @@ -0,0 +1,336 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import logging +import time +from copy import copy +from datetime import datetime +from typing import Any, Dict, List, Optional + +from flask_babel import lazy_gettext as _ +from sqlalchemy.orm import make_transient, Session + +from superset import ConnectorRegistry, db +from superset.commands.base import BaseCommand +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.exceptions import DashboardImportException +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.utils.dashboard_filter_scopes_converter import ( + convert_filter_scopes, + copy_filter_scopes, +) + +logger = logging.getLogger(__name__) + + +def import_chart( + slc_to_import: Slice, + slc_to_override: Optional[Slice], + import_time: Optional[int] = None, +) -> int: + """Inserts or overrides slc in the database. + + remote_id and import_time fields in params_dict are set to track the + slice origin and ensure correct overrides for multiple imports. + Slice.perm is used to find the datasources and connect them. + + :param Slice slc_to_import: Slice object to import + :param Slice slc_to_override: Slice to replace, id matches remote_id + :returns: The resulting id for the imported slice + :rtype: int + """ + session = db.session + make_transient(slc_to_import) + slc_to_import.dashboards = [] + slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time) + + slc_to_import = slc_to_import.copy() + slc_to_import.reset_ownership() + params = slc_to_import.params_dict + datasource = ConnectorRegistry.get_datasource_by_name( + session, + slc_to_import.datasource_type, + params["datasource_name"], + params["schema"], + params["database_name"], + ) + slc_to_import.datasource_id = datasource.id # type: ignore + if slc_to_override: + slc_to_override.override(slc_to_import) + session.flush() + return slc_to_override.id + session.add(slc_to_import) + logger.info("Final slice: %s", str(slc_to_import.to_json())) + session.flush() + return slc_to_import.id + + +def import_dashboard( + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + dashboard_to_import: Dashboard, + import_time: Optional[int] = None, +) -> int: + """Imports the dashboard from the object to the database. + + Once dashboard is imported, json_metadata field is extended and stores + remote_id and import_time. It helps to decide if the dashboard has to + be overridden or just copies over. Slices that belong to this + dashboard will be wired to existing tables. This function can be used + to import/export dashboards between multiple superset instances. + Audit metadata isn't copied over. + """ + + def alter_positions( + dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int] + ) -> None: + """Updates slice_ids in the position json. + + Sample position_json data: + { + "DASHBOARD_VERSION_KEY": "v2", + "DASHBOARD_ROOT_ID": { + "type": "DASHBOARD_ROOT_TYPE", + "id": "DASHBOARD_ROOT_ID", + "children": ["DASHBOARD_GRID_ID"] + }, + "DASHBOARD_GRID_ID": { + "type": "DASHBOARD_GRID_TYPE", + "id": "DASHBOARD_GRID_ID", + "children": ["DASHBOARD_CHART_TYPE-2"] + }, + "DASHBOARD_CHART_TYPE-2": { + "type": "CHART", + "id": "DASHBOARD_CHART_TYPE-2", + "children": [], + "meta": { + "width": 4, + "height": 50, + "chartId": 118 + } + }, + } + """ + position_data = json.loads(dashboard.position_json) + position_json = position_data.values() + for value in position_json: + if ( + isinstance(value, dict) + and value.get("meta") + and value.get("meta", {}).get("chartId") + ): + old_slice_id = value["meta"]["chartId"] + + if old_slice_id in old_to_new_slc_id_dict: + value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id] + dashboard.position_json = json.dumps(position_data) + + logger.info("Started import of the dashboard: %s", dashboard_to_import.to_json()) + session = db.session + logger.info("Dashboard has %d slices", len(dashboard_to_import.slices)) + # copy slices object as Slice.import_slice will mutate the slice + # and will remove the existing dashboard - slice association + slices = copy(dashboard_to_import.slices) + + # Clearing the slug to avoid conflicts + dashboard_to_import.slug = None + + old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}") + old_to_new_slc_id_dict: Dict[int, int] = {} + new_timed_refresh_immune_slices = [] + new_expanded_slices = {} + new_filter_scopes = {} + i_params_dict = dashboard_to_import.params_dict + remote_id_slice_map = { + slc.params_dict["remote_id"]: slc + for slc in session.query(Slice).all() + if "remote_id" in slc.params_dict + } + for slc in slices: + logger.info( + "Importing slice %s from the dashboard: %s", + slc.to_json(), + dashboard_to_import.dashboard_title, + ) + remote_slc = remote_id_slice_map.get(slc.id) + new_slc_id = import_chart(slc, remote_slc, import_time=import_time) + old_to_new_slc_id_dict[slc.id] = new_slc_id + # update json metadata that deals with slice ids + new_slc_id_str = str(new_slc_id) + old_slc_id_str = str(slc.id) + if ( + "timed_refresh_immune_slices" in i_params_dict + and old_slc_id_str in i_params_dict["timed_refresh_immune_slices"] + ): + new_timed_refresh_immune_slices.append(new_slc_id_str) + if ( + "expanded_slices" in i_params_dict + and old_slc_id_str in i_params_dict["expanded_slices"] + ): + new_expanded_slices[new_slc_id_str] = i_params_dict["expanded_slices"][ + old_slc_id_str + ] + + # since PR #9109, filter_immune_slices and filter_immune_slice_fields + # are converted to filter_scopes + # but dashboard create from import may still have old dashboard filter metadata + # here we convert them to new filter_scopes metadata first + filter_scopes = {} + if ( + "filter_immune_slices" in i_params_dict + or "filter_immune_slice_fields" in i_params_dict + ): + filter_scopes = convert_filter_scopes(old_json_metadata, slices) + + if "filter_scopes" in i_params_dict: + filter_scopes = old_json_metadata.get("filter_scopes") + + # then replace old slice id to new slice id: + if filter_scopes: + new_filter_scopes = copy_filter_scopes( + old_to_new_slc_id_dict=old_to_new_slc_id_dict, + old_filter_scopes=filter_scopes, + ) + + # override the dashboard + existing_dashboard = None + for dash in session.query(Dashboard).all(): + if ( + "remote_id" in dash.params_dict + and dash.params_dict["remote_id"] == dashboard_to_import.id + ): + existing_dashboard = dash + + dashboard_to_import = dashboard_to_import.copy() + dashboard_to_import.id = None + dashboard_to_import.reset_ownership() + # position_json can be empty for dashboards + # with charts added from chart-edit page and without re-arranging + if dashboard_to_import.position_json: + alter_positions(dashboard_to_import, old_to_new_slc_id_dict) + dashboard_to_import.alter_params(import_time=import_time) + dashboard_to_import.remove_params(param_to_remove="filter_immune_slices") + dashboard_to_import.remove_params(param_to_remove="filter_immune_slice_fields") + if new_filter_scopes: + dashboard_to_import.alter_params(filter_scopes=new_filter_scopes) + if new_expanded_slices: + dashboard_to_import.alter_params(expanded_slices=new_expanded_slices) + if new_timed_refresh_immune_slices: + dashboard_to_import.alter_params( + timed_refresh_immune_slices=new_timed_refresh_immune_slices + ) + + new_slices = ( + session.query(Slice).filter(Slice.id.in_(old_to_new_slc_id_dict.values())).all() + ) + + if existing_dashboard: + existing_dashboard.override(dashboard_to_import) + existing_dashboard.slices = new_slices + session.flush() + return existing_dashboard.id + + dashboard_to_import.slices = new_slices + session.add(dashboard_to_import) + session.flush() + return dashboard_to_import.id # type: ignore + + +def decode_dashboards( # pylint: disable=too-many-return-statements + o: Dict[str, Any] +) -> Any: + """ + Function to be passed into json.loads obj_hook parameter + Recreates the dashboard object from a json representation. + """ + from superset.connectors.druid.models import ( + DruidCluster, + DruidColumn, + DruidDatasource, + DruidMetric, + ) + + if "__Dashboard__" in o: + return Dashboard(**o["__Dashboard__"]) + if "__Slice__" in o: + return Slice(**o["__Slice__"]) + if "__TableColumn__" in o: + return TableColumn(**o["__TableColumn__"]) + if "__SqlaTable__" in o: + return SqlaTable(**o["__SqlaTable__"]) + if "__SqlMetric__" in o: + return SqlMetric(**o["__SqlMetric__"]) + if "__DruidCluster__" in o: + return DruidCluster(**o["__DruidCluster__"]) + if "__DruidColumn__" in o: + return DruidColumn(**o["__DruidColumn__"]) + if "__DruidDatasource__" in o: + return DruidDatasource(**o["__DruidDatasource__"]) + if "__DruidMetric__" in o: + return DruidMetric(**o["__DruidMetric__"]) + if "__datetime__" in o: + return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S") + + return o + + +def import_dashboards( + session: Session, + content: str, + database_id: Optional[int] = None, + import_time: Optional[int] = None, +) -> None: + """Imports dashboards from a stream to databases""" + current_tt = int(time.time()) + import_time = current_tt if import_time is None else import_time + data = json.loads(content, object_hook=decode_dashboards) + if not data: + raise DashboardImportException(_("No data in file")) + for table in data["datasources"]: + type(table).import_obj(table, database_id, import_time=import_time) + session.commit() + for dashboard in data["dashboards"]: + import_dashboard(dashboard, import_time=import_time) + session.commit() + + +class ImportDashboardsCommand(BaseCommand): + """ + Import dashboard in JSON format. + + This is the original unversioned format used to export and import dashboards + in Superset. + """ + + def __init__(self, contents: Dict[str, str], database_id: Optional[int] = None): + self.contents = contents + self.database_id = database_id + + def run(self) -> None: + self.validate() + + for file_name, content in self.contents.items(): + logger.info("Importing dashboard from file %s", file_name) + import_dashboards(db.session, content, self.database_id) + + def validate(self) -> None: + # ensure all files are JSON + for content in self.contents.values(): + try: + json.loads(content) + except ValueError: + raise diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 3ef50cde82..413e83c3f7 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -16,10 +16,9 @@ # under the License. import json import logging -from copy import copy from functools import partial from json.decoder import JSONDecodeError -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Set, Union from urllib import parse import sqlalchemy as sqla @@ -61,10 +60,6 @@ from superset.models.tags import DashboardUpdater from superset.models.user_attributes import UserAttribute from superset.tasks.thumbnails import cache_dashboard_thumbnail from superset.utils import core as utils -from superset.utils.dashboard_filter_scopes_converter import ( - convert_filter_scopes, - copy_filter_scopes, -) from superset.utils.decorators import debounce from superset.utils.urls import get_url_path @@ -321,181 +316,6 @@ class Dashboard( # pylint: disable=too-many-instance-attributes for (dashboard_id,) in db.engine.execute(filter_query): cls(id=dashboard_id).clear_cache() - @classmethod - def import_obj( - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - cls, - dashboard_to_import: "Dashboard", - import_time: Optional[int] = None, - ) -> int: - """Imports the dashboard from the object to the database. - - Once dashboard is imported, json_metadata field is extended and stores - remote_id and import_time. It helps to decide if the dashboard has to - be overridden or just copies over. Slices that belong to this - dashboard will be wired to existing tables. This function can be used - to import/export dashboards between multiple superset instances. - Audit metadata isn't copied over. - """ - - def alter_positions( - dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int] - ) -> None: - """Updates slice_ids in the position json. - - Sample position_json data: - { - "DASHBOARD_VERSION_KEY": "v2", - "DASHBOARD_ROOT_ID": { - "type": "DASHBOARD_ROOT_TYPE", - "id": "DASHBOARD_ROOT_ID", - "children": ["DASHBOARD_GRID_ID"] - }, - "DASHBOARD_GRID_ID": { - "type": "DASHBOARD_GRID_TYPE", - "id": "DASHBOARD_GRID_ID", - "children": ["DASHBOARD_CHART_TYPE-2"] - }, - "DASHBOARD_CHART_TYPE-2": { - "type": "CHART", - "id": "DASHBOARD_CHART_TYPE-2", - "children": [], - "meta": { - "width": 4, - "height": 50, - "chartId": 118 - } - }, - } - """ - position_data = json.loads(dashboard.position_json) - position_json = position_data.values() - for value in position_json: - if ( - isinstance(value, dict) - and value.get("meta") - and value.get("meta", {}).get("chartId") - ): - old_slice_id = value["meta"]["chartId"] - - if old_slice_id in old_to_new_slc_id_dict: - value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id] - dashboard.position_json = json.dumps(position_data) - - logger.info( - "Started import of the dashboard: %s", dashboard_to_import.to_json() - ) - session = db.session - logger.info("Dashboard has %d slices", len(dashboard_to_import.slices)) - # copy slices object as Slice.import_slice will mutate the slice - # and will remove the existing dashboard - slice association - slices = copy(dashboard_to_import.slices) - - # Clearing the slug to avoid conflicts - dashboard_to_import.slug = None - - old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}") - old_to_new_slc_id_dict: Dict[int, int] = {} - new_timed_refresh_immune_slices = [] - new_expanded_slices = {} - new_filter_scopes = {} - i_params_dict = dashboard_to_import.params_dict - remote_id_slice_map = { - slc.params_dict["remote_id"]: slc - for slc in session.query(Slice).all() - if "remote_id" in slc.params_dict - } - for slc in slices: - logger.info( - "Importing slice %s from the dashboard: %s", - slc.to_json(), - dashboard_to_import.dashboard_title, - ) - remote_slc = remote_id_slice_map.get(slc.id) - new_slc_id = Slice.import_obj(slc, remote_slc, import_time=import_time) - old_to_new_slc_id_dict[slc.id] = new_slc_id - # update json metadata that deals with slice ids - new_slc_id_str = str(new_slc_id) - old_slc_id_str = str(slc.id) - if ( - "timed_refresh_immune_slices" in i_params_dict - and old_slc_id_str in i_params_dict["timed_refresh_immune_slices"] - ): - new_timed_refresh_immune_slices.append(new_slc_id_str) - if ( - "expanded_slices" in i_params_dict - and old_slc_id_str in i_params_dict["expanded_slices"] - ): - new_expanded_slices[new_slc_id_str] = i_params_dict["expanded_slices"][ - old_slc_id_str - ] - - # since PR #9109, filter_immune_slices and filter_immune_slice_fields - # are converted to filter_scopes - # but dashboard create from import may still have old dashboard filter metadata - # here we convert them to new filter_scopes metadata first - filter_scopes = {} - if ( - "filter_immune_slices" in i_params_dict - or "filter_immune_slice_fields" in i_params_dict - ): - filter_scopes = convert_filter_scopes(old_json_metadata, slices) - - if "filter_scopes" in i_params_dict: - filter_scopes = old_json_metadata.get("filter_scopes") - - # then replace old slice id to new slice id: - if filter_scopes: - new_filter_scopes = copy_filter_scopes( - old_to_new_slc_id_dict=old_to_new_slc_id_dict, - old_filter_scopes=filter_scopes, - ) - - # override the dashboard - existing_dashboard = None - for dash in session.query(Dashboard).all(): - if ( - "remote_id" in dash.params_dict - and dash.params_dict["remote_id"] == dashboard_to_import.id - ): - existing_dashboard = dash - - dashboard_to_import = dashboard_to_import.copy() - dashboard_to_import.id = None - dashboard_to_import.reset_ownership() - # position_json can be empty for dashboards - # with charts added from chart-edit page and without re-arranging - if dashboard_to_import.position_json: - alter_positions(dashboard_to_import, old_to_new_slc_id_dict) - dashboard_to_import.alter_params(import_time=import_time) - dashboard_to_import.remove_params(param_to_remove="filter_immune_slices") - dashboard_to_import.remove_params(param_to_remove="filter_immune_slice_fields") - if new_filter_scopes: - dashboard_to_import.alter_params(filter_scopes=new_filter_scopes) - if new_expanded_slices: - dashboard_to_import.alter_params(expanded_slices=new_expanded_slices) - if new_timed_refresh_immune_slices: - dashboard_to_import.alter_params( - timed_refresh_immune_slices=new_timed_refresh_immune_slices - ) - - new_slices = ( - session.query(Slice) - .filter(Slice.id.in_(old_to_new_slc_id_dict.values())) - .all() - ) - - if existing_dashboard: - existing_dashboard.override(dashboard_to_import) - existing_dashboard.slices = new_slices - session.flush() - return existing_dashboard.id - - dashboard_to_import.slices = new_slices - session.add(dashboard_to_import) - session.flush() - return dashboard_to_import.id # type: ignore - @classmethod def export_dashboards( # pylint: disable=too-many-locals cls, dashboard_ids: List[int] diff --git a/superset/models/slice.py b/superset/models/slice.py index 4ed8ba41fb..87b56a3c25 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -25,7 +25,7 @@ from flask_appbuilder.models.decorators import renders from markupsafe import escape, Markup from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import make_transient, relationship +from sqlalchemy.orm import relationship from sqlalchemy.orm.mapper import Mapper from superset import ConnectorRegistry, db, is_feature_enabled, security_manager @@ -289,49 +289,6 @@ class Slice( """ - @classmethod - def import_obj( - cls, - slc_to_import: "Slice", - slc_to_override: Optional["Slice"], - import_time: Optional[int] = None, - ) -> int: - """Inserts or overrides slc in the database. - - remote_id and import_time fields in params_dict are set to track the - slice origin and ensure correct overrides for multiple imports. - Slice.perm is used to find the datasources and connect them. - - :param Slice slc_to_import: Slice object to import - :param Slice slc_to_override: Slice to replace, id matches remote_id - :returns: The resulting id for the imported slice - :rtype: int - """ - session = db.session - make_transient(slc_to_import) - slc_to_import.dashboards = [] - slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time) - - slc_to_import = slc_to_import.copy() - slc_to_import.reset_ownership() - params = slc_to_import.params_dict - datasource = ConnectorRegistry.get_datasource_by_name( - session, - slc_to_import.datasource_type, - params["datasource_name"], - params["schema"], - params["database_name"], - ) - slc_to_import.datasource_id = datasource.id # type: ignore - if slc_to_override: - slc_to_override.override(slc_to_import) - session.flush() - return slc_to_override.id - session.add(slc_to_import) - logger.info("Final slice: %s", str(slc_to_import.to_json())) - session.flush() - return slc_to_import.id - @property def url(self) -> str: return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D" diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py index 6ae500b40f..fc61d0a422 100644 --- a/superset/utils/dashboard_import_export.py +++ b/superset/utils/dashboard_import_export.py @@ -14,82 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json import logging -import time -from datetime import datetime -from io import BytesIO -from typing import Any, Dict, Optional -from flask_babel import lazy_gettext as _ from sqlalchemy.orm import Session -from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn -from superset.exceptions import DashboardImportException from superset.models.dashboard import Dashboard -from superset.models.slice import Slice logger = logging.getLogger(__name__) -def decode_dashboards( # pylint: disable=too-many-return-statements - o: Dict[str, Any] -) -> Any: - """ - Function to be passed into json.loads obj_hook parameter - Recreates the dashboard object from a json representation. - """ - from superset.connectors.druid.models import ( - DruidCluster, - DruidColumn, - DruidDatasource, - DruidMetric, - ) - - if "__Dashboard__" in o: - return Dashboard(**o["__Dashboard__"]) - if "__Slice__" in o: - return Slice(**o["__Slice__"]) - if "__TableColumn__" in o: - return TableColumn(**o["__TableColumn__"]) - if "__SqlaTable__" in o: - return SqlaTable(**o["__SqlaTable__"]) - if "__SqlMetric__" in o: - return SqlMetric(**o["__SqlMetric__"]) - if "__DruidCluster__" in o: - return DruidCluster(**o["__DruidCluster__"]) - if "__DruidColumn__" in o: - return DruidColumn(**o["__DruidColumn__"]) - if "__DruidDatasource__" in o: - return DruidDatasource(**o["__DruidDatasource__"]) - if "__DruidMetric__" in o: - return DruidMetric(**o["__DruidMetric__"]) - if "__datetime__" in o: - return datetime.strptime(o["__datetime__"], "%Y-%m-%dT%H:%M:%S") - - return o - - -def import_dashboards( - session: Session, - data_stream: BytesIO, - database_id: Optional[int] = None, - import_time: Optional[int] = None, -) -> None: - """Imports dashboards from a stream to databases""" - current_tt = int(time.time()) - import_time = current_tt if import_time is None else import_time - data = json.loads(data_stream.read(), object_hook=decode_dashboards) - if not data: - raise DashboardImportException(_("No data in file")) - for table in data["datasources"]: - type(table).import_obj(table, database_id, import_time=import_time) - session.commit() - for dashboard in data["dashboards"]: - Dashboard.import_obj(dashboard, import_time=import_time) - session.commit() - - def export_dashboards(session: Session) -> str: """Returns all dashboards metadata as a json dump""" logger.info("Starting export") diff --git a/superset/views/core.py b/superset/views/core.py index 97eef1fdff..9ff28fc1ec 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -61,6 +61,7 @@ from superset.connectors.sqla.models import ( SqlMetric, TableColumn, ) +from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand from superset.dashboards.dao import DashboardDAO from superset.databases.filters import DatabaseFilter from superset.exceptions import ( @@ -86,7 +87,7 @@ from superset.security.analytics_db_safety import ( from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sql_validators import get_validator_by_name from superset.typing import FlaskResponse -from superset.utils import core as utils, dashboard_import_export +from superset.utils import core as utils from superset.utils.dates import now_as_float from superset.utils.decorators import etag_cache from superset.views.base import ( @@ -545,9 +546,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods success = False database_id = request.form.get("db_id") try: - dashboard_import_export.import_dashboards( - db.session, import_file.stream, database_id - ) + ImportDashboardsCommand( + {import_file.filename: import_file.read()}, database_id + ).run() success = True except DatabaseNotFound as ex: logger.exception(ex) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 0ca8dbdb91..c19b957cb1 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -23,7 +23,7 @@ from flask import g from sqlalchemy.orm.session import make_transient from tests.test_app import app -from superset.utils.dashboard_import_export import decode_dashboards +from superset.dashboards.commands.importers.v0 import decode_dashboards from superset import db, security_manager from superset.connectors.druid.models import ( DruidColumn, @@ -32,6 +32,7 @@ from superset.connectors.druid.models import ( DruidCluster, ) from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.dashboards.commands.importers.v0 import import_chart, import_dashboard from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.utils.core import get_example_database @@ -306,7 +307,7 @@ class TestImportExport(SupersetTestCase): def test_import_1_slice(self): expected_slice = self.create_slice("Import Me", id=10001) - slc_id = Slice.import_obj(expected_slice, None, import_time=1989) + slc_id = import_chart(expected_slice, None, import_time=1989) slc = self.get_slice(slc_id) self.assertEqual(slc.datasource.perm, slc.perm) self.assert_slice_equals(expected_slice, slc) @@ -318,9 +319,9 @@ class TestImportExport(SupersetTestCase): table_id = self.get_table_by_name("wb_health_population").id # table_id != 666, import func will have to find the table slc_1 = self.create_slice("Import Me 1", ds_id=666, id=10002) - slc_id_1 = Slice.import_obj(slc_1, None) + slc_id_1 = import_chart(slc_1, None) slc_2 = self.create_slice("Import Me 2", ds_id=666, id=10003) - slc_id_2 = Slice.import_obj(slc_2, None) + slc_id_2 = import_chart(slc_2, None) imported_slc_1 = self.get_slice(slc_id_1) imported_slc_2 = self.get_slice(slc_id_2) @@ -334,25 +335,25 @@ class TestImportExport(SupersetTestCase): def test_import_slices_for_non_existent_table(self): with self.assertRaises(AttributeError): - Slice.import_obj( + import_chart( self.create_slice("Import Me 3", id=10004, table_name="non_existent"), None, ) def test_import_slices_override(self): slc = self.create_slice("Import Me New", id=10005) - slc_1_id = Slice.import_obj(slc, None, import_time=1990) + slc_1_id = import_chart(slc, None, import_time=1990) slc.slice_name = "Import Me New" imported_slc_1 = self.get_slice(slc_1_id) slc_2 = self.create_slice("Import Me New", id=10005) - slc_2_id = Slice.import_obj(slc_2, imported_slc_1, import_time=1990) + slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990) self.assertEqual(slc_1_id, slc_2_id) imported_slc_2 = self.get_slice(slc_2_id) self.assert_slice_equals(slc, imported_slc_2) def test_import_empty_dashboard(self): empty_dash = self.create_dashboard("empty_dashboard", id=10001) - imported_dash_id = Dashboard.import_obj(empty_dash, import_time=1989) + imported_dash_id = import_dashboard(empty_dash, import_time=1989) imported_dash = self.get_dash(imported_dash_id) self.assert_dash_equals(empty_dash, imported_dash, check_position=False) @@ -377,7 +378,7 @@ class TestImportExport(SupersetTestCase): """.format( slc.id ) - imported_dash_id = Dashboard.import_obj(dash_with_1_slice, import_time=1990) + imported_dash_id = import_dashboard(dash_with_1_slice, import_time=1990) imported_dash = self.get_dash(imported_dash_id) expected_dash = self.create_dashboard("dash_with_1_slice", slcs=[slc], id=10002) @@ -419,7 +420,7 @@ class TestImportExport(SupersetTestCase): } ) - imported_dash_id = Dashboard.import_obj(dash_with_2_slices, import_time=1991) + imported_dash_id = import_dashboard(dash_with_2_slices, import_time=1991) imported_dash = self.get_dash(imported_dash_id) expected_dash = self.create_dashboard( @@ -454,7 +455,7 @@ class TestImportExport(SupersetTestCase): dash_to_import = self.create_dashboard( "override_dashboard", slcs=[e_slc, b_slc], id=10004 ) - imported_dash_id_1 = Dashboard.import_obj(dash_to_import, import_time=1992) + imported_dash_id_1 = import_dashboard(dash_to_import, import_time=1992) # create new instances of the slices e_slc = self.create_slice("e_slc", id=10009, table_name="energy_usage") @@ -463,9 +464,7 @@ class TestImportExport(SupersetTestCase): dash_to_import_override = self.create_dashboard( "override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004 ) - imported_dash_id_2 = Dashboard.import_obj( - dash_to_import_override, import_time=1992 - ) + imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992) # override doesn't change the id self.assertEqual(imported_dash_id_1, imported_dash_id_2) @@ -495,7 +494,7 @@ class TestImportExport(SupersetTestCase): dash_with_1_slice.changed_by = admin_user dash_with_1_slice.owners = [admin_user] - imported_dash_id = Dashboard.import_obj(dash_with_1_slice) + imported_dash_id = import_dashboard(dash_with_1_slice) imported_dash = self.get_dash(imported_dash_id) self.assertEqual(imported_dash.created_by, gamma_user) self.assertEqual(imported_dash.changed_by, gamma_user) @@ -515,7 +514,7 @@ class TestImportExport(SupersetTestCase): dash_with_1_slice = self._create_dashboard_for_import(id_=10300) - imported_dash_id = Dashboard.import_obj(dash_with_1_slice) + imported_dash_id = import_dashboard(dash_with_1_slice) imported_dash = self.get_dash(imported_dash_id) self.assertEqual(imported_dash.created_by, gamma_user) self.assertEqual(imported_dash.changed_by, gamma_user) @@ -531,7 +530,7 @@ class TestImportExport(SupersetTestCase): dash_with_1_slice = self._create_dashboard_for_import(id_=10300) - imported_dash_id = Dashboard.import_obj(dash_with_1_slice) + imported_dash_id = import_dashboard(dash_with_1_slice) imported_dash = self.get_dash(imported_dash_id) self.assertEqual(imported_dash.created_by, gamma_user) self.assertEqual(imported_dash.changed_by, gamma_user)