fix: load example data into correct DB (#12292)

* fix: load example data into correct DB

* Fix force_data

* Fix lint
This commit is contained in:
Beto Dealmeida 2021-01-05 17:52:42 -08:00 committed by GitHub
parent b4f6d353c9
commit 6b2b208b3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 30 deletions

View File

@@ -168,7 +168,7 @@ def load_examples_run(
examples.load_tabbed_dashboard(only_metadata)
# load examples that are stored as YAML config files
examples.load_from_configs()
examples.load_from_configs(force)
@with_appcontext

View File

@@ -55,10 +55,28 @@ class ImportExamplesCommand(ImportModelsCommand):
}
import_error = CommandException
# pylint: disable=too-many-locals
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
    """Initialise the base import command and record the ``force_data`` flag."""
    super().__init__(contents, *args, **kwargs)
    # Default to False when the caller did not ask for a forced data reload.
    force_data = kwargs.get("force_data", False)
    self.force_data = force_data
def run(self) -> None:
    """Validate the configs and import them, committing only on full success.

    Any failure rolls the session back (so no partial import is persisted)
    and is surfaced as ``self.import_error``.
    """
    self.validate()
    # rollback to prevent partial imports
    try:
        # force_data is forwarded so data loads can overwrite existing tables
        self._import(db.session, self._configs, self.overwrite, self.force_data)
        db.session.commit()
    except Exception:
        db.session.rollback()
        raise self.import_error()
# pylint: disable=too-many-locals, arguments-differ
@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
session: Session,
configs: Dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
) -> None:
# import databases
database_ids: Dict[str, int] = {}
@@ -78,7 +96,9 @@ class ImportExamplesCommand(ImportModelsCommand):
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = examples_id
dataset = import_dataset(session, config, overwrite=overwrite)
dataset = import_dataset(
session, config, overwrite=overwrite, force_data=force_data
)
dataset_info[str(dataset.uuid)] = {
"datasource_id": dataset.id,
"datasource_type": "view" if dataset.is_sqllab_view else "table",

View File

@@ -29,6 +29,8 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.visitors import VisitableType
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
logger = logging.getLogger(__name__)
@@ -74,7 +76,10 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
def import_dataset(
session: Session, config: Dict[str, Any], overwrite: bool = False
session: Session,
config: Dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
) -> SqlaTable:
existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
if existing:
@@ -108,28 +113,43 @@ def import_dataset(
if dataset.id is None:
session.flush()
# load data
if data_uri:
data = request.urlopen(data_uri)
if data_uri.endswith(".gz"):
data = gzip.open(data)
df = pd.read_csv(data, encoding="utf-8")
dtype = get_dtype(df, dataset)
# convert temporal columns
for column_name, sqla_type in dtype.items():
if isinstance(sqla_type, (Date, DateTime)):
df[column_name] = pd.to_datetime(df[column_name])
df.to_sql(
dataset.table_name,
con=session.connection(),
schema=dataset.schema,
if_exists="replace",
chunksize=CHUNKSIZE,
dtype=dtype,
index=False,
method="multi",
)
example_database = get_example_database()
table_exists = example_database.has_table_by_name(dataset.table_name)
if data_uri and (not table_exists or force_data):
load_data(data_uri, dataset, example_database, session)
return dataset
def load_data(
    data_uri: str, dataset: SqlaTable, example_database: Database, session: Session
) -> None:
    """Download the CSV behind ``data_uri`` and (re)load it into the dataset's table.

    The target table is replaced wholesale (``if_exists="replace"``). When the
    example database is the same as the main database, the import session's own
    connection is reused so the data load happens inside the import transaction.
    """
    stream = request.urlopen(data_uri)
    if data_uri.endswith(".gz"):
        stream = gzip.open(stream)
    frame = pd.read_csv(stream, encoding="utf-8")

    dtypes = get_dtype(frame, dataset)
    # The CSV yields strings; parse temporal columns into real datetimes.
    for col, col_type in dtypes.items():
        if isinstance(col_type, (Date, DateTime)):
            frame[col] = pd.to_datetime(frame[col])

    # Reuse the import session's connection when possible, to keep the import atomic.
    if example_database.sqlalchemy_uri == get_main_database().sqlalchemy_uri:
        logger.info("Loading data inside the import transaction")
        conn = session.connection()
    else:
        logger.warning("Loading data outside the import transaction")
        conn = example_database.get_sqla_engine()

    frame.to_sql(
        dataset.table_name,
        con=conn,
        schema=dataset.schema,
        if_exists="replace",
        chunksize=CHUNKSIZE,
        dtype=dtypes,
        index=False,
        method="multi",
    )

View File

@@ -24,9 +24,9 @@ from superset.commands.importers.v1.examples import ImportExamplesCommand
YAML_EXTENSIONS = {".yaml", ".yml"}
def load_from_configs(force_data: bool = False) -> None:
    """Import the YAML example configs via ``ImportExamplesCommand``.

    :param force_data: forwarded to the command; when True, example data is
        reloaded even if the target table already exists.
    """
    contents = load_contents()
    command = ImportExamplesCommand(contents, overwrite=True, force_data=force_data)
    command.run()