diff --git a/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx b/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx index 52cf8c8ff7..ce0d3ea34f 100644 --- a/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx +++ b/superset-frontend/spec/javascripts/datasource/DatasourceEditor_spec.jsx @@ -34,17 +34,6 @@ const props = { onChange: () => {}, }; -const extraColumn = { - column_name: 'new_column', - type: 'VARCHAR(10)', - description: null, - filterable: true, - verbose_name: null, - is_dttm: false, - expression: '', - groupby: true, -}; - const DATASOURCE_ENDPOINT = 'glob:*/datasource/external_metadata/*'; describe('DatasourceEditor', () => { @@ -85,11 +74,65 @@ describe('DatasourceEditor', () => { }); }); - it('merges columns', () => { + it('to add, remove and modify columns accordingly', () => { + const columns = [ + { + name: 'ds', + type: 'DATETIME', + nullable: true, + default: '', + primary_key: false, + }, + { + name: 'gender', + type: 'VARCHAR(32)', + nullable: true, + default: '', + primary_key: false, + }, + { + name: 'new_column', + type: 'VARCHAR(10)', + nullable: true, + default: '', + primary_key: false, + }, + ]; + const numCols = props.datasource.columns.length; expect(inst.state.databaseColumns).toHaveLength(numCols); - inst.mergeColumns([extraColumn]); - expect(inst.state.databaseColumns).toHaveLength(numCols + 1); + inst.updateColumns(columns); + expect(inst.state.databaseColumns).toEqual( + expect.arrayContaining([ + { + type: 'DATETIME', + description: null, + filterable: false, + verbose_name: null, + is_dttm: true, + expression: '', + groupby: false, + column_name: 'ds', + }, + { + type: 'VARCHAR(32)', + description: null, + filterable: true, + verbose_name: null, + is_dttm: false, + expression: '', + groupby: true, + column_name: 'gender', + }, + expect.objectContaining({ + column_name: 'new_column', + type: 'VARCHAR(10)', + }), + ]), + ); + 
expect(inst.state.databaseColumns).not.toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'name' })]), + ); }); it('renders isSqla fields', () => { diff --git a/superset-frontend/src/datasource/DatasourceEditor.jsx b/superset-frontend/src/datasource/DatasourceEditor.jsx index 8ce54307bd..dd20ca5532 100644 --- a/superset-frontend/src/datasource/DatasourceEditor.jsx +++ b/superset-frontend/src/datasource/DatasourceEditor.jsx @@ -172,7 +172,7 @@ function ColumnCollectionTable({ ) : ( v ), - type: d => , + type: d => , is_dttm: checkboxGenerator, filterable: checkboxGenerator, groupby: checkboxGenerator, @@ -289,29 +289,58 @@ export class DatasourceEditor extends React.PureComponent { this.validate(this.onChange); } - mergeColumns(cols) { - let { databaseColumns } = this.state; - let hasChanged; - const currentColNames = databaseColumns.map(col => col.column_name); + updateColumns(cols) { + const { databaseColumns } = this.state; + const databaseColumnNames = cols.map(col => col.name); + const currentCols = databaseColumns.reduce( + (agg, col) => ({ + ...agg, + [col.column_name]: col, + }), + {}, + ); + const finalColumns = []; + const results = { + added: [], + modified: [], + removed: databaseColumns + .map(col => col.column_name) + .filter(col => !databaseColumnNames.includes(col)), + }; cols.forEach(col => { - if (currentColNames.indexOf(col.name) < 0) { - // Adding columns - databaseColumns = databaseColumns.concat([ - { - id: shortid.generate(), - column_name: col.name, - type: col.type, - groupby: true, - filterable: true, - }, - ]); - hasChanged = true; + const currentCol = currentCols[col.name]; + if (!currentCol) { + // new column + finalColumns.push({ + id: shortid.generate(), + column_name: col.name, + type: col.type, + groupby: true, + filterable: true, + }); + results.added.push(col.name); + } else if (currentCol.type !== col.type) { + // modified column + finalColumns.push({ + ...currentCol, + type: col.type, + }); + 
results.modified.push(col.name); + } else { + // unchanged + finalColumns.push(currentCol); } }); - if (hasChanged) { - this.setColumns({ databaseColumns }); + if ( + results.added.length || + results.modified.length || + results.removed.length + ) { + this.setColumns({ databaseColumns: finalColumns }); } + return results; } + syncMetadata() { const { datasource } = this.state; // Handle carefully when the schema is empty @@ -326,7 +355,19 @@ export class DatasourceEditor extends React.PureComponent { SupersetClient.get({ endpoint }) .then(({ json }) => { - this.mergeColumns(json); + const results = this.updateColumns(json); + if (results.modified.length) + this.props.addSuccessToast( + t('Modified columns: %s', results.modified.join(', ')), + ); + if (results.removed.length) + this.props.addSuccessToast( + t('Removed columns: %s', results.removed.join(', ')), + ); + if (results.added.length) + this.props.addSuccessToast( + t('New columns added: %s', results.added.join(', ')), + ); this.props.addSuccessToast(t('Metadata has been synced')); this.setState({ metadataLoading: false }); }) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 97336d4b18..cf19d64aec 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -16,6 +16,7 @@ # under the License. import logging from collections import OrderedDict +from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union @@ -82,6 +83,13 @@ class QueryStringExtended(NamedTuple): sql: str +@dataclass +class MetadataResult: + added: List[str] = field(default_factory=list) + removed: List[str] = field(default_factory=list) + modified: List[str] = field(default_factory=list) + + class AnnotationDatasource(BaseDatasource): """ Dummy object so we can query annotations using 'Viz' objects just like regular datasources. 
@@ -1243,10 +1251,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at def get_sqla_table_object(self) -> Table: return self.database.get_table(self.table_name, schema=self.schema) - def fetch_metadata(self, commit: bool = True) -> None: - """Fetches the metadata for the table and merges it in""" + def fetch_metadata(self, commit: bool = True) -> MetadataResult: + """ + Fetches the metadata for the table and merges it in + + :param commit: should the changes be committed or not. + :return: MetadataResult with lists of added, removed and modified column names. + """ try: - table_ = self.get_sqla_table_object() + new_table = self.get_sqla_table_object() except SQLAlchemyError: raise QueryObjectValidationError( _( @@ -1260,35 +1273,46 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at any_date_col = None db_engine_spec = self.database.db_engine_spec db_dialect = self.database.get_dialect() - dbcols = ( - db.session.query(TableColumn) - .filter(TableColumn.table == self) - .filter(or_(TableColumn.column_name == col.name for col in table_.columns)) - ) - dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} + old_columns = db.session.query(TableColumn).filter(TableColumn.table == self) - for col in table_.columns: + old_columns_by_name = {col.column_name: col for col in old_columns} + results = MetadataResult( + removed=[ + col + for col in old_columns_by_name + if col not in {col.name for col in new_table.columns} + ] + ) + + # clear old columns before adding modified columns back + self.columns = [] + for col in new_table.columns: try: datatype = db_engine_spec.column_datatype_to_string( col.type, db_dialect ) except Exception as ex: # pylint: disable=broad-except datatype = "UNKNOWN" - logger.error("Unrecognized data type in %s.%s", table_, col.name) + logger.error("Unrecognized data type in %s.%s", new_table, col.name) logger.exception(ex) - dbcol = dbcols.get(col.name, None) - if not dbcol: - dbcol = 
TableColumn(column_name=col.name, type=datatype, table=self) - dbcol.is_dttm = dbcol.is_temporal - db_engine_spec.alter_new_orm_column(dbcol) + old_column = old_columns_by_name.get(col.name, None) + if not old_column: + results.added.append(col.name) + new_column = TableColumn( + column_name=col.name, type=datatype, table=self + ) + new_column.is_dttm = new_column.is_temporal + db_engine_spec.alter_new_orm_column(new_column) else: - dbcol.type = datatype - dbcol.groupby = True - dbcol.filterable = True - self.columns.append(dbcol) - if not any_date_col and dbcol.is_temporal: + new_column = old_column + if new_column.type != datatype: + results.modified.append(col.name) + new_column.type = datatype + new_column.groupby = True + new_column.filterable = True + self.columns.append(new_column) + if not any_date_col and new_column.is_temporal: any_date_col = col.name - metrics.append( SqlMetric( metric_name="count", @@ -1307,6 +1331,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at db.session.merge(self) if commit: db.session.commit() + return results @classmethod def import_obj( diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 5f3466fb5f..ae2a7c41bf 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -17,7 +17,8 @@ """Views used by the SqlAlchemy connector""" import logging import re -from typing import List, Union +from dataclasses import dataclass, field +from typing import Dict, List, Union from flask import flash, Markup, redirect from flask_appbuilder import CompactCRUDMixin, expose @@ -421,30 +422,78 @@ class TableModelView( # pylint: disable=too-many-ancestors @action( "refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh" ) - def refresh( # pylint: disable=no-self-use + def refresh( # pylint: disable=no-self-use, too-many-branches self, tables: Union["TableModelView", List["TableModelView"]] ) -> FlaskResponse: if not 
isinstance(tables, list): tables = [tables] - successes = [] - failures = [] + + @dataclass + class RefreshResults: + successes: List[TableModelView] = field(default_factory=list) + failures: List[TableModelView] = field(default_factory=list) + added: Dict[str, List[str]] = field(default_factory=dict) + removed: Dict[str, List[str]] = field(default_factory=dict) + modified: Dict[str, List[str]] = field(default_factory=dict) + + results = RefreshResults() + for table_ in tables: try: - table_.fetch_metadata() - successes.append(table_) + metadata_results = table_.fetch_metadata() + if metadata_results.added: + results.added[table_.table_name] = metadata_results.added + if metadata_results.removed: + results.removed[table_.table_name] = metadata_results.removed + if metadata_results.modified: + results.modified[table_.table_name] = metadata_results.modified + results.successes.append(table_) except Exception: # pylint: disable=broad-except - failures.append(table_) + results.failures.append(table_) - if len(successes) > 0: + if len(results.successes) > 0: success_msg = _( "Metadata refreshed for the following table(s): %(tables)s", - tables=", ".join([t.table_name for t in successes]), + tables=", ".join([t.table_name for t in results.successes]), ) flash(success_msg, "info") - if len(failures) > 0: + if results.added: + added_tables = [] + for table, cols in results.added.items(): + added_tables.append(f"{table} ({', '.join(cols)})") + flash( + _( + "The following tables added new columns: %(tables)s", + tables=", ".join(added_tables), + ), + "info", + ) + if results.removed: + removed_tables = [] + for table, cols in results.removed.items(): + removed_tables.append(f"{table} ({', '.join(cols)})") + flash( + _( + "The following tables removed columns: %(tables)s", + tables=", ".join(removed_tables), + ), + "info", + ) + if results.modified: + modified_tables = [] + for table, cols in results.modified.items(): + modified_tables.append(f"{table} ({', '.join(cols)})") 
+ flash( + _( + "The following tables updated column metadata: %(tables)s", + tables=", ".join(modified_tables), + ), + "info", + ) + if len(results.failures) > 0: failure_msg = _( - "Unable to retrieve metadata for the following table(s): %(tables)s", - tables=", ".join([t.table_name for t in failures]), + "Unable to refresh metadata for the following table(s): %(tables)s", + tables=", ".join([t.table_name for t in results.failures]), ) flash(failure_msg, "danger")