From 6b2b208b3b5d61573d2abce15326ed8b055da07b Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Tue, 5 Jan 2021 17:52:42 -0800
Subject: [PATCH] fix: load example data into correct DB (#12292)

* fix: load example data into correct DB

* Fix force_data

* Fix lint
---
 superset/cli.py                              |  2 +-
 superset/commands/importers/v1/examples.py   | 26 ++++++-
 .../datasets/commands/importers/v1/utils.py  | 68 ++++++++++++-------
 superset/examples/utils.py                   |  4 +-
 4 files changed, 70 insertions(+), 30 deletions(-)

diff --git a/superset/cli.py b/superset/cli.py
index 72b02d5b1b..557cc9d0ef 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -168,7 +168,7 @@ def load_examples_run(
         examples.load_tabbed_dashboard(only_metadata)
 
     # load examples that are stored as YAML config files
-    examples.load_from_configs()
+    examples.load_from_configs(force)
 
 
 @with_appcontext
diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py
index c5b2e6e023..2b56ee080d 100644
--- a/superset/commands/importers/v1/examples.py
+++ b/superset/commands/importers/v1/examples.py
@@ -55,10 +55,28 @@ class ImportExamplesCommand(ImportModelsCommand):
     }
     import_error = CommandException
 
-    # pylint: disable=too-many-locals
+    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+        super().__init__(contents, *args, **kwargs)
+        self.force_data = kwargs.get("force_data", False)
+
+    def run(self) -> None:
+        self.validate()
+
+        # rollback to prevent partial imports
+        try:
+            self._import(db.session, self._configs, self.overwrite, self.force_data)
+            db.session.commit()
+        except Exception:
+            db.session.rollback()
+            raise self.import_error()
+
+    # pylint: disable=too-many-locals, arguments-differ
     @staticmethod
     def _import(
-        session: Session, configs: Dict[str, Any], overwrite: bool = False
+        session: Session,
+        configs: Dict[str, Any],
+        overwrite: bool = False,
+        force_data: bool = False,
     ) -> None:
         # import databases
         database_ids: Dict[str, int] = {}
@@ -78,7 +96,9 @@ class ImportExamplesCommand(ImportModelsCommand):
         for file_name, config in configs.items():
             if file_name.startswith("datasets/"):
                 config["database_id"] = examples_id
-                dataset = import_dataset(session, config, overwrite=overwrite)
+                dataset = import_dataset(
+                    session, config, overwrite=overwrite, force_data=force_data
+                )
                 dataset_info[str(dataset.uuid)] = {
                     "datasource_id": dataset.id,
                     "datasource_type": "view" if dataset.is_sqllab_view else "table",
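For context, the new constructor and run() override above mean callers now opt into reloading example data per invocation. A minimal usage sketch (not part of the patch; the YAML contents and file names are placeholders, while the constructor arguments come from the call in superset/examples/utils.py at the end of this patch):

    from superset.commands.importers.v1.examples import ImportExamplesCommand

    # `contents` maps config file names to YAML strings, as produced by
    # load_contents() in superset/examples/utils.py (placeholder values here).
    contents = {
        "databases/examples.yaml": "...",
        "datasets/examples/my_table.yaml": "...",
    }

    # force_data=True reloads CSV data even when the target table already exists.
    command = ImportExamplesCommand(contents, overwrite=True, force_data=True)
    command.run()  # validates, imports, and commits; rolls back on any failure
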
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index d1cd0bb991..5e59545ae7 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -29,6 +29,8 @@ from sqlalchemy.orm import Session
 from sqlalchemy.sql.visitors import VisitableType
 
 from superset.connectors.sqla.models import SqlaTable
+from superset.models.core import Database
+from superset.utils.core import get_example_database, get_main_database
 
 logger = logging.getLogger(__name__)
 
@@ -74,7 +76,10 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
 
 
 def import_dataset(
-    session: Session, config: Dict[str, Any], overwrite: bool = False
+    session: Session,
+    config: Dict[str, Any],
+    overwrite: bool = False,
+    force_data: bool = False,
 ) -> SqlaTable:
     existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
     if existing:
@@ -108,28 +113,43 @@ def import_dataset(
     if dataset.id is None:
         session.flush()
 
-    # load data
-    if data_uri:
-        data = request.urlopen(data_uri)
-        if data_uri.endswith(".gz"):
-            data = gzip.open(data)
-        df = pd.read_csv(data, encoding="utf-8")
-        dtype = get_dtype(df, dataset)
-
-        # convert temporal columns
-        for column_name, sqla_type in dtype.items():
-            if isinstance(sqla_type, (Date, DateTime)):
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        df.to_sql(
-            dataset.table_name,
-            con=session.connection(),
-            schema=dataset.schema,
-            if_exists="replace",
-            chunksize=CHUNKSIZE,
-            dtype=dtype,
-            index=False,
-            method="multi",
-        )
+    example_database = get_example_database()
+    table_exists = example_database.has_table_by_name(dataset.table_name)
+    if data_uri and (not table_exists or force_data):
+        load_data(data_uri, dataset, example_database, session)
 
     return dataset
+
+
+def load_data(
+    data_uri: str, dataset: SqlaTable, example_database: Database, session: Session
+) -> None:
+    data = request.urlopen(data_uri)
+    if data_uri.endswith(".gz"):
+        data = gzip.open(data)
+    df = pd.read_csv(data, encoding="utf-8")
+    dtype = get_dtype(df, dataset)
+
+    # convert temporal columns
+    for column_name, sqla_type in dtype.items():
+        if isinstance(sqla_type, (Date, DateTime)):
+            df[column_name] = pd.to_datetime(df[column_name])
+
+    # reuse session when loading data if possible, to make import atomic
+    if example_database.sqlalchemy_uri == get_main_database().sqlalchemy_uri:
+        logger.info("Loading data inside the import transaction")
+        connection = session.connection()
+    else:
+        logger.warning("Loading data outside the import transaction")
+        connection = example_database.get_sqla_engine()
+
+    df.to_sql(
+        dataset.table_name,
+        con=connection,
+        schema=dataset.schema,
+        if_exists="replace",
+        chunksize=CHUNKSIZE,
+        dtype=dtype,
+        index=False,
+        method="multi",
+    )
diff --git a/superset/examples/utils.py b/superset/examples/utils.py
index 723f2bceca..951b741b28 100644
--- a/superset/examples/utils.py
+++ b/superset/examples/utils.py
@@ -24,9 +24,9 @@
 from superset.commands.importers.v1.examples import ImportExamplesCommand
 
 YAML_EXTENSIONS = {".yaml", ".yml"}
 
 
-def load_from_configs() -> None:
+def load_from_configs(force_data: bool = False) -> None:
     contents = load_contents()
-    command = ImportExamplesCommand(contents, overwrite=True)
+    command = ImportExamplesCommand(contents, overwrite=True, force_data=force_data)
     command.run()
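The data path that load_data() keeps is plain pandas: read the CSV, coerce temporal columns, and bulk-insert with to_sql. A self-contained sketch of that path (illustrative only; an in-memory SQLite engine stands in for the examples database, and the literal chunk size stands in for the module's CHUNKSIZE constant):

    import io

    import pandas as pd
    from sqlalchemy import DateTime, create_engine

    engine = create_engine("sqlite://")
    df = pd.read_csv(io.StringIO("ds,value\n2021-01-01,1\n2021-01-02,2\n"))

    # stand-in for what get_dtype() derives from the dataset's column types
    dtype = {"ds": DateTime()}

    # parse temporal columns before writing, as load_data() does
    for column_name, sqla_type in dtype.items():
        if isinstance(sqla_type, DateTime):
            df[column_name] = pd.to_datetime(df[column_name])

    df.to_sql(
        "my_table",
        con=engine,
        if_exists="replace",
        chunksize=512,
        dtype=dtype,
        index=False,
        method="multi",
    )

Note the trade-off the patch documents in its log messages: when the examples database and the metadata database share a SQLAlchemy URI, the write goes through session.connection() and rolls back with the rest of the import; otherwise it goes through a separate engine and cannot participate in that transaction.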