diff --git a/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx b/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx
index 52cf8c8ff7..ce0d3ea34f 100644
--- a/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx
+++ b/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx
@@ -34,17 +34,6 @@ const props = {
onChange: () => {},
};
-const extraColumn = {
- column_name: 'new_column',
- type: 'VARCHAR(10)',
- description: null,
- filterable: true,
- verbose_name: null,
- is_dttm: false,
- expression: '',
- groupby: true,
-};
-
const DATASOURCE_ENDPOINT = 'glob:*/datasource/external_metadata/*';
describe('DatasourceEditor', () => {
@@ -85,11 +74,65 @@ describe('DatasourceEditor', () => {
});
});
- it('merges columns', () => {
+ it('should add, remove and modify columns accordingly', () => {
+ const columns = [
+ {
+ name: 'ds',
+ type: 'DATETIME',
+ nullable: true,
+ default: '',
+ primary_key: false,
+ },
+ {
+ name: 'gender',
+ type: 'VARCHAR(32)',
+ nullable: true,
+ default: '',
+ primary_key: false,
+ },
+ {
+ name: 'new_column',
+ type: 'VARCHAR(10)',
+ nullable: true,
+ default: '',
+ primary_key: false,
+ },
+ ];
+
const numCols = props.datasource.columns.length;
expect(inst.state.databaseColumns).toHaveLength(numCols);
- inst.mergeColumns([extraColumn]);
- expect(inst.state.databaseColumns).toHaveLength(numCols + 1);
+ inst.updateColumns(columns);
+ expect(inst.state.databaseColumns).toEqual(
+ expect.arrayContaining([
+ {
+ type: 'DATETIME',
+ description: null,
+ filterable: false,
+ verbose_name: null,
+ is_dttm: true,
+ expression: '',
+ groupby: false,
+ column_name: 'ds',
+ },
+ {
+ type: 'VARCHAR(32)',
+ description: null,
+ filterable: true,
+ verbose_name: null,
+ is_dttm: false,
+ expression: '',
+ groupby: true,
+ column_name: 'gender',
+ },
+ expect.objectContaining({
+ column_name: 'new_column',
+ type: 'VARCHAR(10)',
+ }),
+ ]),
+ );
+ expect(inst.state.databaseColumns).not.toEqual(
+ expect.arrayContaining([expect.objectContaining({ name: 'name' })]),
+ );
});
it('renders isSqla fields', () => {
diff --git a/superset-frontend/src/datasource/DatasourceEditor.jsx b/superset-frontend/src/datasource/DatasourceEditor.jsx
index 8ce54307bd..dd20ca5532 100644
--- a/superset-frontend/src/datasource/DatasourceEditor.jsx
+++ b/superset-frontend/src/datasource/DatasourceEditor.jsx
@@ -172,7 +172,7 @@ function ColumnCollectionTable({
) : (
v
),
- type: d => ,
+ type: d => ,
is_dttm: checkboxGenerator,
filterable: checkboxGenerator,
groupby: checkboxGenerator,
@@ -289,29 +289,58 @@ export class DatasourceEditor extends React.PureComponent {
this.validate(this.onChange);
}
- mergeColumns(cols) {
- let { databaseColumns } = this.state;
- let hasChanged;
- const currentColNames = databaseColumns.map(col => col.column_name);
+ updateColumns(cols) {
+ const { databaseColumns } = this.state;
+ const databaseColumnNames = cols.map(col => col.name);
+ const currentCols = databaseColumns.reduce(
+ (agg, col) => ({
+ ...agg,
+ [col.column_name]: col,
+ }),
+ {},
+ );
+ const finalColumns = [];
+ const results = {
+ added: [],
+ modified: [],
+ removed: databaseColumns
+ .map(col => col.column_name)
+ .filter(col => !databaseColumnNames.includes(col)),
+ };
cols.forEach(col => {
- if (currentColNames.indexOf(col.name) < 0) {
- // Adding columns
- databaseColumns = databaseColumns.concat([
- {
- id: shortid.generate(),
- column_name: col.name,
- type: col.type,
- groupby: true,
- filterable: true,
- },
- ]);
- hasChanged = true;
+ const currentCol = currentCols[col.name];
+ if (!currentCol) {
+ // new column
+ finalColumns.push({
+ id: shortid.generate(),
+ column_name: col.name,
+ type: col.type,
+ groupby: true,
+ filterable: true,
+ });
+ results.added.push(col.name);
+ } else if (currentCol.type !== col.type) {
+ // modified column
+ finalColumns.push({
+ ...currentCol,
+ type: col.type,
+ });
+ results.modified.push(col.name);
+ } else {
+ // unchanged
+ finalColumns.push(currentCol);
}
});
- if (hasChanged) {
- this.setColumns({ databaseColumns });
+ if (
+ results.added.length ||
+ results.modified.length ||
+ results.removed.length
+ ) {
+ this.setColumns({ databaseColumns: finalColumns });
}
+ return results;
}
+
syncMetadata() {
const { datasource } = this.state;
// Handle carefully when the schema is empty
@@ -326,7 +355,19 @@ export class DatasourceEditor extends React.PureComponent {
SupersetClient.get({ endpoint })
.then(({ json }) => {
- this.mergeColumns(json);
+ const results = this.updateColumns(json);
+ if (results.modified.length)
+ this.props.addSuccessToast(
+ t('Modified columns: %s', results.modified.join(', ')),
+ );
+ if (results.removed.length)
+ this.props.addSuccessToast(
+ t('Removed columns: %s', results.removed.join(', ')),
+ );
+ if (results.added.length)
+ this.props.addSuccessToast(
+ t('New columns added: %s', results.added.join(', ')),
+ );
this.props.addSuccessToast(t('Metadata has been synced'));
this.setState({ metadataLoading: false });
})
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 97336d4b18..cf19d64aec 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -16,6 +16,7 @@
# under the License.
import logging
from collections import OrderedDict
+from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
@@ -82,6 +83,13 @@ class QueryStringExtended(NamedTuple):
sql: str
+@dataclass
+class MetadataResult:
+ added: List[str] = field(default_factory=list)
+ removed: List[str] = field(default_factory=list)
+ modified: List[str] = field(default_factory=list)
+
+
class AnnotationDatasource(BaseDatasource):
""" Dummy object so we can query annotations using 'Viz' objects just like
regular datasources.
@@ -1243,10 +1251,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
def get_sqla_table_object(self) -> Table:
return self.database.get_table(self.table_name, schema=self.schema)
- def fetch_metadata(self, commit: bool = True) -> None:
- """Fetches the metadata for the table and merges it in"""
+ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
+ """
+ Fetches the metadata for the table and merges it in
+
+ :param commit: should the changes be committed or not.
+ :return: A MetadataResult with lists of added, removed and modified column names.
+ """
try:
- table_ = self.get_sqla_table_object()
+ new_table = self.get_sqla_table_object()
except SQLAlchemyError:
raise QueryObjectValidationError(
_(
@@ -1260,35 +1273,46 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
any_date_col = None
db_engine_spec = self.database.db_engine_spec
db_dialect = self.database.get_dialect()
- dbcols = (
- db.session.query(TableColumn)
- .filter(TableColumn.table == self)
- .filter(or_(TableColumn.column_name == col.name for col in table_.columns))
- )
- dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+ old_columns = db.session.query(TableColumn).filter(TableColumn.table == self)
- for col in table_.columns:
+ old_columns_by_name = {col.column_name: col for col in old_columns}
+ results = MetadataResult(
+ removed=[
+ col
+ for col in old_columns_by_name
+ if col not in {col.name for col in new_table.columns}
+ ]
+ )
+
+ # clear old columns before adding modified columns back
+ self.columns = []
+ for col in new_table.columns:
try:
datatype = db_engine_spec.column_datatype_to_string(
col.type, db_dialect
)
except Exception as ex: # pylint: disable=broad-except
datatype = "UNKNOWN"
- logger.error("Unrecognized data type in %s.%s", table_, col.name)
+ logger.error("Unrecognized data type in %s.%s", new_table, col.name)
logger.exception(ex)
- dbcol = dbcols.get(col.name, None)
- if not dbcol:
- dbcol = TableColumn(column_name=col.name, type=datatype, table=self)
- dbcol.is_dttm = dbcol.is_temporal
- db_engine_spec.alter_new_orm_column(dbcol)
+ old_column = old_columns_by_name.get(col.name, None)
+ if not old_column:
+ results.added.append(col.name)
+ new_column = TableColumn(
+ column_name=col.name, type=datatype, table=self
+ )
+ new_column.is_dttm = new_column.is_temporal
+ db_engine_spec.alter_new_orm_column(new_column)
else:
- dbcol.type = datatype
- dbcol.groupby = True
- dbcol.filterable = True
- self.columns.append(dbcol)
- if not any_date_col and dbcol.is_temporal:
+ new_column = old_column
+ if new_column.type != datatype:
+ results.modified.append(col.name)
+ new_column.type = datatype
+ new_column.groupby = True
+ new_column.filterable = True
+ self.columns.append(new_column)
+ if not any_date_col and new_column.is_temporal:
any_date_col = col.name
-
metrics.append(
SqlMetric(
metric_name="count",
@@ -1307,6 +1331,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
db.session.merge(self)
if commit:
db.session.commit()
+ return results
@classmethod
def import_obj(
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index 5f3466fb5f..ae2a7c41bf 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -17,7 +17,8 @@
"""Views used by the SqlAlchemy connector"""
import logging
import re
-from typing import List, Union
+from dataclasses import dataclass, field
+from typing import Dict, List, Union
from flask import flash, Markup, redirect
from flask_appbuilder import CompactCRUDMixin, expose
@@ -421,30 +422,78 @@ class TableModelView( # pylint: disable=too-many-ancestors
@action(
"refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
)
- def refresh( # pylint: disable=no-self-use
+ def refresh( # pylint: disable=no-self-use, too-many-branches
self, tables: Union["TableModelView", List["TableModelView"]]
) -> FlaskResponse:
if not isinstance(tables, list):
tables = [tables]
- successes = []
- failures = []
+
+ @dataclass
+ class RefreshResults:
+ successes: List[TableModelView] = field(default_factory=list)
+ failures: List[TableModelView] = field(default_factory=list)
+ added: Dict[str, List[str]] = field(default_factory=dict)
+ removed: Dict[str, List[str]] = field(default_factory=dict)
+ modified: Dict[str, List[str]] = field(default_factory=dict)
+
+ results = RefreshResults()
+
for table_ in tables:
try:
- table_.fetch_metadata()
- successes.append(table_)
+ metadata_results = table_.fetch_metadata()
+ if metadata_results.added:
+ results.added[table_.table_name] = metadata_results.added
+ if metadata_results.removed:
+ results.removed[table_.table_name] = metadata_results.removed
+ if metadata_results.modified:
+ results.modified[table_.table_name] = metadata_results.modified
+ results.successes.append(table_)
except Exception: # pylint: disable=broad-except
- failures.append(table_)
+ results.failures.append(table_)
- if len(successes) > 0:
+ if len(results.successes) > 0:
success_msg = _(
"Metadata refreshed for the following table(s): %(tables)s",
- tables=", ".join([t.table_name for t in successes]),
+ tables=", ".join([t.table_name for t in results.successes]),
)
flash(success_msg, "info")
- if len(failures) > 0:
+ if results.added:
+ added_tables = []
+ for table, cols in results.added.items():
+ added_tables.append(f"{table} ({', '.join(cols)})")
+ flash(
+ _(
+ "The following tables added new columns: %(tables)s",
+ tables=", ".join(added_tables),
+ ),
+ "info",
+ )
+ if results.removed:
+ removed_tables = []
+ for table, cols in results.removed.items():
+ removed_tables.append(f"{table} ({', '.join(cols)})")
+ flash(
+ _(
+ "The following tables removed columns: %(tables)s",
+ tables=", ".join(removed_tables),
+ ),
+ "info",
+ )
+ if results.modified:
+ modified_tables = []
+ for table, cols in results.modified.items():
+ modified_tables.append(f"{table} ({', '.join(cols)})")
+ flash(
+ _(
+ "The following tables update column metadata: %(tables)s",
+ tables=", ".join(modified_tables),
+ ),
+ "info",
+ )
+ if len(results.failures) > 0:
failure_msg = _(
- "Unable to retrieve metadata for the following table(s): %(tables)s",
- tables=", ".join([t.table_name for t in failures]),
+ "Unable to refresh metadata for the following table(s): %(tables)s",
+ tables=", ".join([t.table_name for t in results.failures]),
)
flash(failure_msg, "danger")