mirror of https://github.com/apache/superset.git
fix: load example data into correct DB (#12292)
* fix: load example data into correct DB * Fix force_data * Fix lint
This commit is contained in:
parent
b4f6d353c9
commit
6b2b208b3b
|
@ -168,7 +168,7 @@ def load_examples_run(
|
|||
examples.load_tabbed_dashboard(only_metadata)
|
||||
|
||||
# load examples that are stored as YAML config files
|
||||
examples.load_from_configs()
|
||||
examples.load_from_configs(force)
|
||||
|
||||
|
||||
@with_appcontext
|
||||
|
|
|
@ -55,10 +55,28 @@ class ImportExamplesCommand(ImportModelsCommand):
|
|||
}
|
||||
import_error = CommandException
|
||||
|
||||
# pylint: disable=too-many-locals
|
||||
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
    """Set up the examples import command.

    All arguments are forwarded unchanged to the parent command's
    constructor; in addition the optional ``force_data`` keyword is
    remembered on the instance.

    :param contents: mapping passed straight through to the parent class
    :param kwargs: may carry ``force_data``; a missing key means ``False``
    """
    super().__init__(contents, *args, **kwargs)

    # ``force_data`` is read from kwargs (not popped), so the parent
    # constructor above has already seen the very same keyword.
    force_data = kwargs.get("force_data", False)
    self.force_data = force_data
|
||||
|
||||
def run(self) -> None:
    """Validate the configs and import them in one transaction.

    ``self.validate()`` runs first. The import itself is then executed
    against ``db.session`` and committed; if anything fails the session is
    rolled back so that no partial import is left behind.

    :raises: ``self.import_error`` when the import or the commit fails. The
        original exception is attached as ``__cause__`` so the root failure
        stays visible in tracebacks and logs.
    """
    self.validate()

    # rollback to prevent partial imports
    try:
        self._import(db.session, self._configs, self.overwrite, self.force_data)
        db.session.commit()
    except Exception as ex:
        db.session.rollback()
        # chain explicitly so the underlying error is not mistaken for a
        # failure that happened while handling it
        raise self.import_error() from ex
|
||||
|
||||
# pylint: disable=too-many-locals, arguments-differ
|
||||
@staticmethod
|
||||
def _import(
|
||||
session: Session, configs: Dict[str, Any], overwrite: bool = False
|
||||
session: Session,
|
||||
configs: Dict[str, Any],
|
||||
overwrite: bool = False,
|
||||
force_data: bool = False,
|
||||
) -> None:
|
||||
# import databases
|
||||
database_ids: Dict[str, int] = {}
|
||||
|
@ -78,7 +96,9 @@ class ImportExamplesCommand(ImportModelsCommand):
|
|||
for file_name, config in configs.items():
|
||||
if file_name.startswith("datasets/"):
|
||||
config["database_id"] = examples_id
|
||||
dataset = import_dataset(session, config, overwrite=overwrite)
|
||||
dataset = import_dataset(
|
||||
session, config, overwrite=overwrite, force_data=force_data
|
||||
)
|
||||
dataset_info[str(dataset.uuid)] = {
|
||||
"datasource_id": dataset.id,
|
||||
"datasource_type": "view" if dataset.is_sqllab_view else "table",
|
||||
|
|
|
@ -29,6 +29,8 @@ from sqlalchemy.orm import Session
|
|||
from sqlalchemy.sql.visitors import VisitableType
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import get_example_database, get_main_database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -74,7 +76,10 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
|
|||
|
||||
|
||||
def import_dataset(
|
||||
session: Session, config: Dict[str, Any], overwrite: bool = False
|
||||
session: Session,
|
||||
config: Dict[str, Any],
|
||||
overwrite: bool = False,
|
||||
force_data: bool = False,
|
||||
) -> SqlaTable:
|
||||
existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
|
||||
if existing:
|
||||
|
@ -108,28 +113,43 @@ def import_dataset(
|
|||
if dataset.id is None:
|
||||
session.flush()
|
||||
|
||||
# load data
|
||||
if data_uri:
|
||||
data = request.urlopen(data_uri)
|
||||
if data_uri.endswith(".gz"):
|
||||
data = gzip.open(data)
|
||||
df = pd.read_csv(data, encoding="utf-8")
|
||||
dtype = get_dtype(df, dataset)
|
||||
|
||||
# convert temporal columns
|
||||
for column_name, sqla_type in dtype.items():
|
||||
if isinstance(sqla_type, (Date, DateTime)):
|
||||
df[column_name] = pd.to_datetime(df[column_name])
|
||||
|
||||
df.to_sql(
|
||||
dataset.table_name,
|
||||
con=session.connection(),
|
||||
schema=dataset.schema,
|
||||
if_exists="replace",
|
||||
chunksize=CHUNKSIZE,
|
||||
dtype=dtype,
|
||||
index=False,
|
||||
method="multi",
|
||||
)
|
||||
example_database = get_example_database()
|
||||
table_exists = example_database.has_table_by_name(dataset.table_name)
|
||||
if data_uri and (not table_exists or force_data):
|
||||
load_data(data_uri, dataset, example_database, session)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def load_data(
    data_uri: str, dataset: SqlaTable, example_database: Database, session: Session
) -> None:
    """Download a CSV file and write it into the table backing ``dataset``.

    The file at ``data_uri`` is fetched with ``urllib`` and transparently
    decompressed when the URI ends in ``.gz``. It is parsed with pandas and
    written with ``DataFrame.to_sql(if_exists="replace")`` into
    ``dataset.table_name`` / ``dataset.schema``.

    :param data_uri: URI of the CSV (optionally gzip-compressed) file
    :param dataset: dataset whose table name, schema and column types
        (via ``get_dtype``) drive the load
    :param example_database: database the data is written to
    :param session: session of the surrounding import; its connection is
        reused only when ``example_database`` has the same SQLAlchemy URI
        as the main database
    """
    # Close both the HTTP response and the gzip wrapper deterministically.
    # ``read_csv`` consumes the stream completely before the blocks exit.
    with request.urlopen(data_uri) as response:
        if data_uri.endswith(".gz"):
            with gzip.open(response) as data:
                df = pd.read_csv(data, encoding="utf-8")
        else:
            df = pd.read_csv(response, encoding="utf-8")
    dtype = get_dtype(df, dataset)

    # convert temporal columns
    for column_name, sqla_type in dtype.items():
        if isinstance(sqla_type, (Date, DateTime)):
            df[column_name] = pd.to_datetime(df[column_name])

    # reuse session when loading data if possible, to make import atomic
    if example_database.sqlalchemy_uri == get_main_database().sqlalchemy_uri:
        logger.info("Loading data inside the import transaction")
        connection = session.connection()
    else:
        logger.warning("Loading data outside the import transaction")
        connection = example_database.get_sqla_engine()

    df.to_sql(
        dataset.table_name,
        con=connection,
        schema=dataset.schema,
        if_exists="replace",
        chunksize=CHUNKSIZE,
        dtype=dtype,
        index=False,
        method="multi",
    )
|
||||
|
|
|
@ -24,9 +24,9 @@ from superset.commands.importers.v1.examples import ImportExamplesCommand
|
|||
YAML_EXTENSIONS = {".yaml", ".yml"}
|
||||
|
||||
|
||||
def load_from_configs(force_data: bool = False) -> None:
    """Import the examples returned by ``load_contents()``.

    Builds an ``ImportExamplesCommand`` with ``overwrite=True`` and runs it.

    :param force_data: forwarded to the command as ``force_data``;
        defaults to ``False`` so existing zero-argument callers keep working
    """
    contents = load_contents()
    command = ImportExamplesCommand(contents, overwrite=True, force_data=force_data)
    command.run()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue