= ({
};
const renderPasswordFields = () => {
- if (passwordFields.length === 0) {
+ if (
+ passwordFields.length === 0 &&
+ sshTunnelPasswordFields.length === 0 &&
+ sshTunnelPrivateKeyFields.length === 0 &&
+ sshTunnelPrivateKeyPasswordFields.length === 0
+ ) {
return null;
}
+ const files = [
+ ...new Set([
+ ...passwordFields,
+ ...sshTunnelPasswordFields,
+ ...sshTunnelPrivateKeyFields,
+ ...sshTunnelPrivateKeyPasswordFields,
+ ]),
+ ];
+
return (
<>
{t('Database passwords')}
{passwordsNeededMessage}
- {passwordFields.map(fileName => (
-
-
- {fileName}
- *
-
-
- setPasswords({ ...passwords, [fileName]: event.target.value })
- }
- />
-
+ {files.map(fileName => (
+ <>
+ {passwordFields?.indexOf(fileName) >= 0 && (
+
+
+ {t('%s PASSWORD', fileName.slice(10))}
+ *
+
+
+ setPasswords({
+ ...passwords,
+ [fileName]: event.target.value,
+ })
+ }
+ />
+
+ )}
+ {sshTunnelPasswordFields?.indexOf(fileName) >= 0 && (
+
+
+ {t('%s SSH TUNNEL PASSWORD', fileName.slice(10))}
+ *
+
+
+ setSSHTunnelPasswords({
+ ...sshTunnelPasswords,
+ [fileName]: event.target.value,
+ })
+ }
+ data-test="ssh_tunnel_password"
+ />
+
+ )}
+ {sshTunnelPrivateKeyFields?.indexOf(fileName) >= 0 && (
+
+
+ {t('%s SSH TUNNEL PRIVATE KEY', fileName.slice(10))}
+ *
+
+
+ )}
+ {sshTunnelPrivateKeyPasswordFields?.indexOf(fileName) >= 0 && (
+
+
+ {t('%s SSH TUNNEL PRIVATE KEY PASSWORD', fileName.slice(10))}
+ *
+
+
+ setSSHTunnelPrivateKeyPasswords({
+ ...sshTunnelPrivateKeyPasswords,
+ [fileName]: event.target.value,
+ })
+ }
+ data-test="ssh_tunnel_private_key_password"
+ />
+
+ )}
+ >
))}
>
);
@@ -303,7 +448,12 @@ const ImportModelsModal: FunctionComponent = ({
{errorMessage && (
0}
+ showDbInstallInstructions={
+ passwordFields.length > 0 ||
+ sshTunnelPasswordFields.length > 0 ||
+ sshTunnelPrivateKeyFields.length > 0 ||
+ sshTunnelPrivateKeyPasswordFields.length > 0
+ }
/>
)}
{renderPasswordFields()}
diff --git a/superset-frontend/src/pages/ChartList/index.tsx b/superset-frontend/src/pages/ChartList/index.tsx
index e02d848a5d..9788d61e4a 100644
--- a/superset-frontend/src/pages/ChartList/index.tsx
+++ b/superset-frontend/src/pages/ChartList/index.tsx
@@ -197,6 +197,16 @@ function ChartList(props: ChartListProps) {
const [importingChart, showImportModal] = useState(false);
const [passwordFields, setPasswordFields] = useState([]);
const [preparingExport, setPreparingExport] = useState(false);
+ const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
+ string[]
+ >([]);
+ const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
+ string[]
+ >([]);
+ const [
+ sshTunnelPrivateKeyPasswordFields,
+ setSSHTunnelPrivateKeyPasswordFields,
+ ] = useState([]);
// TODO: Fix usage of localStorage keying on the user id
const userSettings = dangerouslyGetItemDoNotUse(userId?.toString(), null) as {
@@ -888,6 +898,14 @@ function ChartList(props: ChartListProps) {
onHide={closeChartImportModal}
passwordFields={passwordFields}
setPasswordFields={setPasswordFields}
+ sshTunnelPasswordFields={sshTunnelPasswordFields}
+ setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
+ sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
+ setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
+ sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
+ setSSHTunnelPrivateKeyPasswordFields={
+ setSSHTunnelPrivateKeyPasswordFields
+ }
/>
{preparingExport && }
>
diff --git a/superset-frontend/src/views/CRUD/dashboard/DashboardList.tsx b/superset-frontend/src/views/CRUD/dashboard/DashboardList.tsx
index d26900a29d..d6d192e22b 100644
--- a/superset-frontend/src/views/CRUD/dashboard/DashboardList.tsx
+++ b/superset-frontend/src/views/CRUD/dashboard/DashboardList.tsx
@@ -145,6 +145,16 @@ function DashboardList(props: DashboardListProps) {
const [preparingExport, setPreparingExport] = useState(false);
const enableBroadUserAccess =
bootstrapData?.common?.conf?.ENABLE_BROAD_ACTIVITY_ACCESS;
+ const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
+ string[]
+ >([]);
+ const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
+ string[]
+ >([]);
+ const [
+ sshTunnelPrivateKeyPasswordFields,
+ setSSHTunnelPrivateKeyPasswordFields,
+ ] = useState([]);
const openDashboardImportModal = () => {
showImportModal(true);
@@ -789,6 +799,14 @@ function DashboardList(props: DashboardListProps) {
onHide={closeDashboardImportModal}
passwordFields={passwordFields}
setPasswordFields={setPasswordFields}
+ sshTunnelPasswordFields={sshTunnelPasswordFields}
+ setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
+ sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
+ setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
+ sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
+ setSSHTunnelPrivateKeyPasswordFields={
+ setSSHTunnelPrivateKeyPasswordFields
+ }
/>
{preparingExport && }
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx
index 35151598e0..06a0cb349c 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx
@@ -555,11 +555,29 @@ const DatabaseModal: FunctionComponent = ({
const [isLoading, setLoading] = useState(false);
const [testInProgress, setTestInProgress] = useState(false);
const [passwords, setPasswords] = useState>({});
+ const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
+ Record
+ >({});
+ const [sshTunnelPrivateKeys, setSSHTunnelPrivateKeys] = useState<
+ Record
+ >({});
+ const [sshTunnelPrivateKeyPasswords, setSSHTunnelPrivateKeyPasswords] =
+ useState>({});
const [confirmedOverwrite, setConfirmedOverwrite] = useState(false);
const [fileList, setFileList] = useState([]);
const [importingModal, setImportingModal] = useState(false);
const [importingErrorMessage, setImportingErrorMessage] = useState();
const [passwordFields, setPasswordFields] = useState([]);
+ const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
+ string[]
+ >([]);
+ const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
+ string[]
+ >([]);
+ const [
+ sshTunnelPrivateKeyPasswordFields,
+ setSSHTunnelPrivateKeyPasswordFields,
+ ] = useState([]);
const SSHTunnelSwitchComponent =
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
@@ -657,7 +675,13 @@ const DatabaseModal: FunctionComponent = ({
setImportingModal(false);
setImportingErrorMessage('');
setPasswordFields([]);
+ setSSHTunnelPasswordFields([]);
+ setSSHTunnelPrivateKeyFields([]);
+ setSSHTunnelPrivateKeyPasswordFields([]);
setPasswords({});
+ setSSHTunnelPasswords({});
+ setSSHTunnelPrivateKeys({});
+ setSSHTunnelPrivateKeyPasswords({});
setConfirmedOverwrite(false);
setUseSSHTunneling(false);
onHide();
@@ -678,6 +702,9 @@ const DatabaseModal: FunctionComponent = ({
state: {
alreadyExists,
passwordsNeeded,
+ sshPasswordNeeded,
+ sshPrivateKeyNeeded,
+ sshPrivateKeyPasswordNeeded,
loading: importLoading,
failed: importErrored,
},
@@ -811,6 +838,9 @@ const DatabaseModal: FunctionComponent = ({
const dbId = await importResource(
fileList[0].originFileObj,
passwords,
+ sshTunnelPasswords,
+ sshTunnelPrivateKeys,
+ sshTunnelPrivateKeyPasswords,
confirmedOverwrite,
);
if (dbId) {
@@ -983,7 +1013,13 @@ const DatabaseModal: FunctionComponent = ({
setImportingModal(false);
setImportingErrorMessage('');
setPasswordFields([]);
+ setSSHTunnelPasswordFields([]);
+ setSSHTunnelPrivateKeyFields([]);
+ setSSHTunnelPrivateKeyPasswordFields([]);
setPasswords({});
+ setSSHTunnelPasswords({});
+ setSSHTunnelPrivateKeys({});
+ setSSHTunnelPrivateKeyPasswords({});
}
setDB({ type: ActionType.reset });
setFileList([]);
@@ -993,7 +1029,13 @@ const DatabaseModal: FunctionComponent = ({
if (
importLoading ||
(alreadyExists.length && !confirmedOverwrite) ||
- (passwordsNeeded.length && JSON.stringify(passwords) === '{}')
+ (passwordsNeeded.length && JSON.stringify(passwords) === '{}') ||
+ (sshPasswordNeeded.length &&
+ JSON.stringify(sshTunnelPasswords) === '{}') ||
+ (sshPrivateKeyNeeded.length &&
+ JSON.stringify(sshTunnelPrivateKeys) === '{}') ||
+ (sshPrivateKeyPasswordNeeded.length &&
+ JSON.stringify(sshTunnelPrivateKeyPasswords) === '{}')
)
return true;
return false;
@@ -1098,13 +1140,24 @@ const DatabaseModal: FunctionComponent = ({
!importLoading &&
!alreadyExists.length &&
!passwordsNeeded.length &&
+ !sshPasswordNeeded.length &&
+ !sshPrivateKeyNeeded.length &&
+ !sshPrivateKeyPasswordNeeded.length &&
!isLoading && // This prevents a double toast for non-related imports
!importErrored // This prevents a success toast on error
) {
onClose();
addSuccessToast(t('Database connected'));
}
- }, [alreadyExists, passwordsNeeded, importLoading, importErrored]);
+ }, [
+ alreadyExists,
+ passwordsNeeded,
+ importLoading,
+ importErrored,
+ sshPasswordNeeded,
+ sshPrivateKeyNeeded,
+ sshPrivateKeyPasswordNeeded,
+ ]);
useEffect(() => {
if (show) {
@@ -1153,6 +1206,18 @@ const DatabaseModal: FunctionComponent = ({
setPasswordFields([...passwordsNeeded]);
}, [passwordsNeeded]);
+ useEffect(() => {
+ setSSHTunnelPasswordFields([...sshPasswordNeeded]);
+ }, [sshPasswordNeeded]);
+
+ useEffect(() => {
+ setSSHTunnelPrivateKeyFields([...sshPrivateKeyNeeded]);
+ }, [sshPrivateKeyNeeded]);
+
+ useEffect(() => {
+ setSSHTunnelPrivateKeyPasswordFields([...sshPrivateKeyPasswordNeeded]);
+ }, [sshPrivateKeyPasswordNeeded]);
+
useEffect(() => {
if (db && isSSHTunneling) {
setUseSSHTunneling(!isEmpty(db?.ssh_tunnel));
@@ -1162,7 +1227,13 @@ const DatabaseModal: FunctionComponent = ({
const onDbImport = async (info: UploadChangeParam) => {
setImportingErrorMessage('');
setPasswordFields([]);
+ setSSHTunnelPasswordFields([]);
+ setSSHTunnelPrivateKeyFields([]);
+ setSSHTunnelPrivateKeyPasswordFields([]);
setPasswords({});
+ setSSHTunnelPasswords({});
+ setSSHTunnelPrivateKeys({});
+ setSSHTunnelPrivateKeyPasswords({});
setImportingModal(true);
setFileList([
{
@@ -1175,15 +1246,33 @@ const DatabaseModal: FunctionComponent = ({
const dbId = await importResource(
info.file.originFileObj,
passwords,
+ sshTunnelPasswords,
+ sshTunnelPrivateKeys,
+ sshTunnelPrivateKeyPasswords,
confirmedOverwrite,
);
if (dbId) onDatabaseAdd?.();
};
const passwordNeededField = () => {
- if (!passwordFields.length) return null;
+ if (
+ !passwordFields.length &&
+ !sshTunnelPasswordFields.length &&
+ !sshTunnelPrivateKeyFields.length &&
+ !sshTunnelPrivateKeyPasswordFields.length
+ )
+ return null;
- return passwordFields.map(database => (
+ const files = [
+ ...new Set([
+ ...passwordFields,
+ ...sshTunnelPasswordFields,
+ ...sshTunnelPrivateKeyFields,
+ ...sshTunnelPrivateKeyPasswordFields,
+ ]),
+ ];
+
+ return files.map(database => (
<>
= ({
)}
/>
- ) =>
- setPasswords({ ...passwords, [database]: event.target.value })
- }
- validationMethods={{ onBlur: () => {} }}
- errorMessage={validationErrors?.password_needed}
- label={t('%s PASSWORD', database.slice(10))}
- css={formScrollableStyles}
- />
+ {passwordFields?.indexOf(database) >= 0 && (
+ ) =>
+ setPasswords({ ...passwords, [database]: event.target.value })
+ }
+ validationMethods={{ onBlur: () => {} }}
+ errorMessage={validationErrors?.password_needed}
+ label={t('%s PASSWORD', database.slice(10))}
+ css={formScrollableStyles}
+ />
+ )}
+ {sshTunnelPasswordFields?.indexOf(database) >= 0 && (
+ ) =>
+ setSSHTunnelPasswords({
+ ...sshTunnelPasswords,
+ [database]: event.target.value,
+ })
+ }
+ validationMethods={{ onBlur: () => {} }}
+ errorMessage={validationErrors?.ssh_tunnel_password_needed}
+ label={t('%s SSH TUNNEL PASSWORD', database.slice(10))}
+ css={formScrollableStyles}
+ />
+ )}
+ {sshTunnelPrivateKeyFields?.indexOf(database) >= 0 && (
+ ) =>
+ setSSHTunnelPrivateKeys({
+ ...sshTunnelPrivateKeys,
+ [database]: event.target.value,
+ })
+ }
+ validationMethods={{ onBlur: () => {} }}
+ errorMessage={validationErrors?.ssh_tunnel_private_key_needed}
+ label={t('%s SSH TUNNEL PRIVATE KEY', database.slice(10))}
+ css={formScrollableStyles}
+ />
+ )}
+ {sshTunnelPrivateKeyPasswordFields?.indexOf(database) >= 0 && (
+ ) =>
+ setSSHTunnelPrivateKeyPasswords({
+ ...sshTunnelPrivateKeyPasswords,
+ [database]: event.target.value,
+ })
+ }
+ validationMethods={{ onBlur: () => {} }}
+ errorMessage={
+ validationErrors?.ssh_tunnel_private_key_password_needed
+ }
+ label={t('%s SSH TUNNEL PRIVATE KEY PASSWORD', database.slice(10))}
+ css={formScrollableStyles}
+ />
+ )}
>
));
};
@@ -1468,7 +1615,14 @@ const DatabaseModal: FunctionComponent = ({
);
};
- if (fileList.length > 0 && (alreadyExists.length || passwordFields.length)) {
+ if (
+ fileList.length > 0 &&
+ (alreadyExists.length ||
+ passwordFields.length ||
+ sshTunnelPasswordFields.length ||
+ sshTunnelPrivateKeyFields.length ||
+ sshTunnelPrivateKeyPasswordFields.length)
+ ) {
return (
[
diff --git a/superset-frontend/src/views/CRUD/data/dataset/DatasetList.tsx b/superset-frontend/src/views/CRUD/data/dataset/DatasetList.tsx
index c2236f403b..bc3342f69c 100644
--- a/superset-frontend/src/views/CRUD/data/dataset/DatasetList.tsx
+++ b/superset-frontend/src/views/CRUD/data/dataset/DatasetList.tsx
@@ -163,6 +163,16 @@ const DatasetList: FunctionComponent = ({
const [importingDataset, showImportModal] = useState(false);
const [passwordFields, setPasswordFields] = useState([]);
const [preparingExport, setPreparingExport] = useState(false);
+ const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
+ string[]
+ >([]);
+ const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
+ string[]
+ >([]);
+ const [
+ sshTunnelPrivateKeyPasswordFields,
+ setSSHTunnelPrivateKeyPasswordFields,
+ ] = useState([]);
const openDatasetImportModal = () => {
showImportModal(true);
@@ -822,6 +832,14 @@ const DatasetList: FunctionComponent = ({
onHide={closeDatasetImportModal}
passwordFields={passwordFields}
setPasswordFields={setPasswordFields}
+ sshTunnelPasswordFields={sshTunnelPasswordFields}
+ setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
+ sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
+ setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
+ sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
+ setSSHTunnelPrivateKeyPasswordFields={
+ setSSHTunnelPrivateKeyPasswordFields
+ }
/>
{preparingExport && }
>
diff --git a/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryList.tsx b/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryList.tsx
index 3409710db5..d3c96d4c30 100644
--- a/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryList.tsx
+++ b/superset-frontend/src/views/CRUD/data/savedquery/SavedQueryList.tsx
@@ -115,6 +115,16 @@ function SavedQueryList({
const [importingSavedQuery, showImportModal] = useState(false);
const [passwordFields, setPasswordFields] = useState([]);
const [preparingExport, setPreparingExport] = useState(false);
+ const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
+ string[]
+ >([]);
+ const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
+ string[]
+ >([]);
+ const [
+ sshTunnelPrivateKeyPasswordFields,
+ setSSHTunnelPrivateKeyPasswordFields,
+ ] = useState([]);
const openSavedQueryImportModal = () => {
showImportModal(true);
@@ -577,6 +587,14 @@ function SavedQueryList({
onHide={closeSavedQueryImportModal}
passwordFields={passwordFields}
setPasswordFields={setPasswordFields}
+ sshTunnelPasswordFields={sshTunnelPasswordFields}
+ setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
+ sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
+ setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
+ sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
+ setSSHTunnelPrivateKeyPasswordFields={
+ setSSHTunnelPrivateKeyPasswordFields
+ }
/>
{preparingExport && }
>
diff --git a/superset-frontend/src/views/CRUD/hooks.ts b/superset-frontend/src/views/CRUD/hooks.ts
index 80a6c4793b..6812d1e0c5 100644
--- a/superset-frontend/src/views/CRUD/hooks.ts
+++ b/superset-frontend/src/views/CRUD/hooks.ts
@@ -25,6 +25,9 @@ import {
getAlreadyExists,
getPasswordsNeeded,
hasTerminalValidation,
+ getSSHPasswordsNeeded,
+ getSSHPrivateKeysNeeded,
+ getSSHPrivateKeyPasswordsNeeded,
} from 'src/views/CRUD/utils';
import { FetchDataConfig } from 'src/components/ListView';
import { FilterValue } from 'src/components/ListView/types';
@@ -386,6 +389,9 @@ interface ImportResourceState {
loading: boolean;
passwordsNeeded: string[];
alreadyExists: string[];
+ sshPasswordNeeded: string[];
+ sshPrivateKeyNeeded: string[];
+ sshPrivateKeyPasswordNeeded: string[];
failed: boolean;
}
@@ -398,6 +404,9 @@ export function useImportResource(
loading: false,
passwordsNeeded: [],
alreadyExists: [],
+ sshPasswordNeeded: [],
+ sshPrivateKeyNeeded: [],
+ sshPrivateKeyPasswordNeeded: [],
failed: false,
});
@@ -409,6 +418,9 @@ export function useImportResource(
(
bundle: File,
databasePasswords: Record = {},
+ sshTunnelPasswords: Record = {},
+ sshTunnelPrivateKey: Record = {},
+ sshTunnelPrivateKeyPasswords: Record = {},
overwrite = false,
) => {
// Set loading state
@@ -436,6 +448,33 @@ export function useImportResource(
if (overwrite) {
formData.append('overwrite', 'true');
}
+ /* The import bundle may contain ssh tunnel passwords; if required
+ * they should be provided by the user during import.
+ */
+ if (sshTunnelPasswords) {
+ formData.append(
+ 'ssh_tunnel_passwords',
+ JSON.stringify(sshTunnelPasswords),
+ );
+ }
+ /* The import bundle may contain ssh tunnel private_key; if required
+ * they should be provided by the user during import.
+ */
+ if (sshTunnelPrivateKey) {
+ formData.append(
+ 'ssh_tunnel_private_keys',
+ JSON.stringify(sshTunnelPrivateKey),
+ );
+ }
+ /* The import bundle may contain ssh tunnel private_key_password; if required
+ * they should be provided by the user during import.
+ */
+ if (sshTunnelPrivateKeyPasswords) {
+ formData.append(
+ 'ssh_tunnel_private_key_passwords',
+ JSON.stringify(sshTunnelPrivateKeyPasswords),
+ );
+ }
return SupersetClient.post({
endpoint: `/api/v1/${resourceName}/import/`,
@@ -446,6 +485,9 @@ export function useImportResource(
updateState({
passwordsNeeded: [],
alreadyExists: [],
+ sshPasswordNeeded: [],
+ sshPrivateKeyNeeded: [],
+ sshPrivateKeyPasswordNeeded: [],
failed: false,
});
return true;
@@ -479,6 +521,11 @@ export function useImportResource(
} else {
updateState({
passwordsNeeded: getPasswordsNeeded(error.errors),
+ sshPasswordNeeded: getSSHPasswordsNeeded(error.errors),
+ sshPrivateKeyNeeded: getSSHPrivateKeysNeeded(error.errors),
+ sshPrivateKeyPasswordNeeded: getSSHPrivateKeyPasswordsNeeded(
+ error.errors,
+ ),
alreadyExists: getAlreadyExists(error.errors),
});
}
diff --git a/superset-frontend/src/views/CRUD/utils.test.tsx b/superset-frontend/src/views/CRUD/utils.test.tsx
index fa41455d85..b9b8047f43 100644
--- a/superset-frontend/src/views/CRUD/utils.test.tsx
+++ b/superset-frontend/src/views/CRUD/utils.test.tsx
@@ -22,9 +22,15 @@ import {
getAlreadyExists,
getFilterValues,
getPasswordsNeeded,
+ getSSHPasswordsNeeded,
+ getSSHPrivateKeysNeeded,
+ getSSHPrivateKeyPasswordsNeeded,
hasTerminalValidation,
isAlreadyExists,
isNeedsPassword,
+ isNeedsSSHPassword,
+ isNeedsSSHPrivateKey,
+ isNeedsSSHPrivateKeyPassword,
} from 'src/views/CRUD/utils';
import { User } from 'src/types/bootstrapTypes';
import { Filter, TableTab } from './types';
@@ -112,6 +118,72 @@ const passwordNeededErrors = {
],
};
+const sshTunnelPasswordNeededErrors = {
+ errors: [
+ {
+ message: 'Error importing database',
+ error_type: 'GENERIC_COMMAND_ERROR',
+ level: 'warning',
+ extra: {
+ 'databases/imported_database.yaml': {
+ _schema: ['Must provide a password for the ssh tunnel'],
+ },
+ issue_codes: [
+ {
+ code: 1010,
+ message:
+ 'Issue 1010 - Superset encountered an error while running a command.',
+ },
+ ],
+ },
+ },
+ ],
+};
+
+const sshTunnelPrivateKeyNeededErrors = {
+ errors: [
+ {
+ message: 'Error importing database',
+ error_type: 'GENERIC_COMMAND_ERROR',
+ level: 'warning',
+ extra: {
+ 'databases/imported_database.yaml': {
+ _schema: ['Must provide a private key for the ssh tunnel'],
+ },
+ issue_codes: [
+ {
+ code: 1010,
+ message:
+ 'Issue 1010 - Superset encountered an error while running a command.',
+ },
+ ],
+ },
+ },
+ ],
+};
+
+const sshTunnelPrivateKeyPasswordNeededErrors = {
+ errors: [
+ {
+ message: 'Error importing database',
+ error_type: 'GENERIC_COMMAND_ERROR',
+ level: 'warning',
+ extra: {
+ 'databases/imported_database.yaml': {
+ _schema: ['Must provide a private key password for the ssh tunnel'],
+ },
+ issue_codes: [
+ {
+ code: 1010,
+ message:
+ 'Issue 1010 - Superset encountered an error while running a command.',
+ },
+ ],
+ },
+ },
+ ],
+};
+
test('identifies error payloads indicating that password is needed', () => {
let needsPassword;
@@ -129,6 +201,63 @@ test('identifies error payloads indicating that password is needed', () => {
expect(needsPassword).toBe(false);
});
+test('identifies error payloads indicating that ssh_tunnel password is needed', () => {
+ let needsSSHTunnelPassword;
+
+ needsSSHTunnelPassword = isNeedsSSHPassword({
+ _schema: ['Must provide a password for the ssh tunnel'],
+ });
+ expect(needsSSHTunnelPassword).toBe(true);
+
+ needsSSHTunnelPassword = isNeedsSSHPassword(
+ 'Database already exists and `overwrite=true` was not passed',
+ );
+ expect(needsSSHTunnelPassword).toBe(false);
+
+ needsSSHTunnelPassword = isNeedsSSHPassword({
+ type: ['Must be equal to Database.'],
+ });
+ expect(needsSSHTunnelPassword).toBe(false);
+});
+
+test('identifies error payloads indicating that ssh_tunnel private_key is needed', () => {
+ let needsSSHTunnelPrivateKey;
+
+ needsSSHTunnelPrivateKey = isNeedsSSHPrivateKey({
+ _schema: ['Must provide a private key for the ssh tunnel'],
+ });
+ expect(needsSSHTunnelPrivateKey).toBe(true);
+
+ needsSSHTunnelPrivateKey = isNeedsSSHPrivateKey(
+ 'Database already exists and `overwrite=true` was not passed',
+ );
+ expect(needsSSHTunnelPrivateKey).toBe(false);
+
+ needsSSHTunnelPrivateKey = isNeedsSSHPrivateKey({
+ type: ['Must be equal to Database.'],
+ });
+ expect(needsSSHTunnelPrivateKey).toBe(false);
+});
+
+test('identifies error payloads indicating that ssh_tunnel private_key_password is needed', () => {
+ let needsSSHTunnelPrivateKeyPassword;
+
+ needsSSHTunnelPrivateKeyPassword = isNeedsSSHPrivateKeyPassword({
+ _schema: ['Must provide a private key password for the ssh tunnel'],
+ });
+ expect(needsSSHTunnelPrivateKeyPassword).toBe(true);
+
+ needsSSHTunnelPrivateKeyPassword = isNeedsSSHPrivateKeyPassword(
+ 'Database already exists and `overwrite=true` was not passed',
+ );
+ expect(needsSSHTunnelPrivateKeyPassword).toBe(false);
+
+ needsSSHTunnelPrivateKeyPassword = isNeedsSSHPrivateKeyPassword({
+ type: ['Must be equal to Database.'],
+ });
+ expect(needsSSHTunnelPrivateKeyPassword).toBe(false);
+});
+
test('identifies error payloads indicating that overwrite confirmation is needed', () => {
let alreadyExists;
@@ -151,6 +280,29 @@ test('extracts DB configuration files that need passwords', () => {
expect(passwordsNeeded).toEqual(['databases/imported_database.yaml']);
});
+test('extracts DB configuration files that need ssh_tunnel passwords', () => {
+ const sshPasswordsNeeded = getSSHPasswordsNeeded(
+ sshTunnelPasswordNeededErrors.errors,
+ );
+ expect(sshPasswordsNeeded).toEqual(['databases/imported_database.yaml']);
+});
+
+test('extracts DB configuration files that need ssh_tunnel private_keys', () => {
+ const sshPrivateKeysNeeded = getSSHPrivateKeysNeeded(
+ sshTunnelPrivateKeyNeededErrors.errors,
+ );
+ expect(sshPrivateKeysNeeded).toEqual(['databases/imported_database.yaml']);
+});
+
+test('extracts DB configuration files that need ssh_tunnel private_key_passwords', () => {
+ const sshPrivateKeyPasswordsNeeded = getSSHPrivateKeyPasswordsNeeded(
+ sshTunnelPrivateKeyPasswordNeededErrors.errors,
+ );
+ expect(sshPrivateKeyPasswordsNeeded).toEqual([
+ 'databases/imported_database.yaml',
+ ]);
+});
+
test('extracts files that need overwrite confirmation', () => {
const alreadyExists = getAlreadyExists(overwriteNeededErrors.errors);
expect(alreadyExists).toEqual(['databases/imported_database.yaml']);
@@ -167,6 +319,17 @@ test('detects if the error message is terminal or if it requires uses interventi
isTerminal = hasTerminalValidation(passwordNeededErrors.errors);
expect(isTerminal).toBe(false);
+
+ isTerminal = hasTerminalValidation(sshTunnelPasswordNeededErrors.errors);
+ expect(isTerminal).toBe(false);
+
+ isTerminal = hasTerminalValidation(sshTunnelPrivateKeyNeededErrors.errors);
+ expect(isTerminal).toBe(false);
+
+ isTerminal = hasTerminalValidation(
+ sshTunnelPrivateKeyPasswordNeededErrors.errors,
+ );
+ expect(isTerminal).toBe(false);
});
test('error message is terminal when the "extra" field contains only the "issue_codes" key', () => {
diff --git a/superset-frontend/src/views/CRUD/utils.tsx b/superset-frontend/src/views/CRUD/utils.tsx
index 190f93dc8b..f12f13e027 100644
--- a/superset-frontend/src/views/CRUD/utils.tsx
+++ b/superset-frontend/src/views/CRUD/utils.tsx
@@ -371,8 +371,34 @@ export /* eslint-disable no-underscore-dangle */
const isNeedsPassword = (payload: any) =>
typeof payload === 'object' &&
Array.isArray(payload._schema) &&
- payload._schema.length === 1 &&
- payload._schema[0] === 'Must provide a password for the database';
+ !!payload._schema?.find(
+ (e: string) => e === 'Must provide a password for the database',
+ );
+
+export /* eslint-disable no-underscore-dangle */
+const isNeedsSSHPassword = (payload: any) =>
+ typeof payload === 'object' &&
+ Array.isArray(payload._schema) &&
+ !!payload._schema?.find(
+ (e: string) => e === 'Must provide a password for the ssh tunnel',
+ );
+
+export /* eslint-disable no-underscore-dangle */
+const isNeedsSSHPrivateKey = (payload: any) =>
+ typeof payload === 'object' &&
+ Array.isArray(payload._schema) &&
+ !!payload._schema?.find(
+ (e: string) => e === 'Must provide a private key for the ssh tunnel',
+ );
+
+export /* eslint-disable no-underscore-dangle */
+const isNeedsSSHPrivateKeyPassword = (payload: any) =>
+ typeof payload === 'object' &&
+ Array.isArray(payload._schema) &&
+ !!payload._schema?.find(
+ (e: string) =>
+ e === 'Must provide a private key password for the ssh tunnel',
+ );
export const isAlreadyExists = (payload: any) =>
typeof payload === 'string' &&
@@ -387,6 +413,35 @@ export const getPasswordsNeeded = (errors: Record[]) =>
)
.flat();
+export const getSSHPasswordsNeeded = (errors: Record[]) =>
+ errors
+ .map(error =>
+ Object.entries(error.extra)
+ .filter(([, payload]) => isNeedsSSHPassword(payload))
+ .map(([fileName]) => fileName),
+ )
+ .flat();
+
+export const getSSHPrivateKeysNeeded = (errors: Record[]) =>
+ errors
+ .map(error =>
+ Object.entries(error.extra)
+ .filter(([, payload]) => isNeedsSSHPrivateKey(payload))
+ .map(([fileName]) => fileName),
+ )
+ .flat();
+
+export const getSSHPrivateKeyPasswordsNeeded = (
+ errors: Record[],
+) =>
+ errors
+ .map(error =>
+ Object.entries(error.extra)
+ .filter(([, payload]) => isNeedsSSHPrivateKeyPassword(payload))
+ .map(([fileName]) => fileName),
+ )
+ .flat();
+
export const getAlreadyExists = (errors: Record[]) =>
errors
.map(error =>
@@ -405,7 +460,12 @@ export const hasTerminalValidation = (errors: Record[]) =>
if (noIssuesCodes.length === 0) return true;
return !noIssuesCodes.every(
- ([, payload]) => isNeedsPassword(payload) || isAlreadyExists(payload),
+ ([, payload]) =>
+ isNeedsPassword(payload) ||
+ isAlreadyExists(payload) ||
+ isNeedsSSHPassword(payload) ||
+ isNeedsSSHPrivateKey(payload) ||
+ isNeedsSSHPrivateKeyPassword(payload),
);
});
diff --git a/superset/charts/api.py b/superset/charts/api.py
index 0a9b61af91..5b453a2d99 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -882,6 +882,30 @@ class ChartRestApi(BaseSupersetModelRestApi):
overwrite:
description: overwrite existing charts?
type: boolean
+ ssh_tunnel_passwords:
+ description: >-
+ JSON map of passwords for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_password"}`.
+ type: string
+ ssh_tunnel_private_keys:
+ description: >-
+ JSON map of private_keys for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key"}`.
+ type: string
+ ssh_tunnel_private_key_passwords:
+ description: >-
+ JSON map of private_key_passwords for each ssh_tunnel associated
+ to a featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key_password"}`.
+ type: string
responses:
200:
description: Chart import result
@@ -918,9 +942,29 @@ class ChartRestApi(BaseSupersetModelRestApi):
else None
)
overwrite = request.form.get("overwrite") == "true"
+ ssh_tunnel_passwords = (
+ json.loads(request.form["ssh_tunnel_passwords"])
+ if "ssh_tunnel_passwords" in request.form
+ else None
+ )
+ ssh_tunnel_private_keys = (
+ json.loads(request.form["ssh_tunnel_private_keys"])
+ if "ssh_tunnel_private_keys" in request.form
+ else None
+ )
+ ssh_tunnel_priv_key_passwords = (
+ json.loads(request.form["ssh_tunnel_private_key_passwords"])
+ if "ssh_tunnel_private_key_passwords" in request.form
+ else None
+ )
command = ImportChartsCommand(
- contents, passwords=passwords, overwrite=overwrite
+ contents,
+ passwords=passwords,
+ overwrite=overwrite,
+ ssh_tunnel_passwords=ssh_tunnel_passwords,
+ ssh_tunnel_private_keys=ssh_tunnel_private_keys,
+ ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run()
return self.response(200, message="OK")
diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py
index fc67d1822e..a4208ded41 100644
--- a/superset/commands/importers/v1/__init__.py
+++ b/superset/commands/importers/v1/__init__.py
@@ -47,6 +47,15 @@ class ImportModelsCommand(BaseCommand):
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
+ self.ssh_tunnel_passwords: Dict[str, str] = (
+ kwargs.get("ssh_tunnel_passwords") or {}
+ )
+ self.ssh_tunnel_private_keys: Dict[str, str] = (
+ kwargs.get("ssh_tunnel_private_keys") or {}
+ )
+ self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+ kwargs.get("ssh_tunnel_priv_key_passwords") or {}
+ )
self.overwrite: bool = kwargs.get("overwrite", False)
self._configs: Dict[str, Any] = {}
@@ -88,7 +97,13 @@ class ImportModelsCommand(BaseCommand):
# load the configs and make sure we have confirmation to overwrite existing models
self._configs = load_configs(
- self.contents, self.schemas, self.passwords, exceptions
+ self.contents,
+ self.schemas,
+ self.passwords,
+ exceptions,
+ self.ssh_tunnel_passwords,
+ self.ssh_tunnel_private_keys,
+ self.ssh_tunnel_priv_key_passwords,
)
self._prevent_overwrite_existing_model(exceptions)
diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py
index b63de59536..da4d7808d3 100644
--- a/superset/commands/importers/v1/assets.py
+++ b/superset/commands/importers/v1/assets.py
@@ -68,6 +68,15 @@ class ImportAssetsCommand(BaseCommand):
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
+ self.ssh_tunnel_passwords: Dict[str, str] = (
+ kwargs.get("ssh_tunnel_passwords") or {}
+ )
+ self.ssh_tunnel_private_keys: Dict[str, str] = (
+ kwargs.get("ssh_tunnel_private_keys") or {}
+ )
+ self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+ kwargs.get("ssh_tunnel_priv_key_passwords") or {}
+ )
self._configs: Dict[str, Any] = {}
@staticmethod
@@ -153,7 +162,13 @@ class ImportAssetsCommand(BaseCommand):
validate_metadata_type(metadata, "assets", exceptions)
self._configs = load_configs(
- self.contents, self.schemas, self.passwords, exceptions
+ self.contents,
+ self.schemas,
+ self.passwords,
+ exceptions,
+ self.ssh_tunnel_passwords,
+ self.ssh_tunnel_private_keys,
+ self.ssh_tunnel_priv_key_passwords,
)
if exceptions:
diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py
index 1e73886682..c8fb97c53d 100644
--- a/superset/commands/importers/v1/utils.py
+++ b/superset/commands/importers/v1/utils.py
@@ -24,6 +24,7 @@ from marshmallow.exceptions import ValidationError
from superset import db
from superset.commands.importers.exceptions import IncorrectVersionError
+from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
METADATA_FILE_NAME = "metadata.yaml"
@@ -93,11 +94,15 @@ def validate_metadata_type(
exceptions.append(exc)
+# pylint: disable=too-many-locals,too-many-arguments
def load_configs(
contents: Dict[str, str],
schemas: Dict[str, Schema],
passwords: Dict[str, str],
exceptions: List[ValidationError],
+ ssh_tunnel_passwords: Dict[str, str],
+ ssh_tunnel_private_keys: Dict[str, str],
+ ssh_tunnel_priv_key_passwords: Dict[str, str],
) -> Dict[str, Any]:
configs: Dict[str, Any] = {}
@@ -106,6 +111,25 @@ def load_configs(
str(uuid): password
for uuid, password in db.session.query(Database.uuid, Database.password).all()
}
+ # load existing ssh_tunnels so we can apply the password validation
+ db_ssh_tunnel_passwords: Dict[str, str] = {
+ str(uuid): password
+ for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
+ }
+ # load existing ssh_tunnels so we can apply the private_key validation
+ db_ssh_tunnel_private_keys: Dict[str, str] = {
+ str(uuid): private_key
+ for uuid, private_key in db.session.query(
+ SSHTunnel.uuid, SSHTunnel.private_key
+ ).all()
+ }
+ # load existing ssh_tunnels so we can apply the private_key_password validation
+ db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
+ str(uuid): private_key_password
+ for uuid, private_key_password in db.session.query(
+ SSHTunnel.uuid, SSHTunnel.private_key_password
+ ).all()
+ }
for file_name, content in contents.items():
# skip directories
if not content:
@@ -123,6 +147,42 @@ def load_configs(
elif prefix == "databases" and config["uuid"] in db_passwords:
config["password"] = db_passwords[config["uuid"]]
+ # populate ssh_tunnel_passwords from the request or from existing DBs
+ if file_name in ssh_tunnel_passwords:
+ config["ssh_tunnel"]["password"] = ssh_tunnel_passwords[file_name]
+ elif (
+ prefix == "databases" and config["uuid"] in db_ssh_tunnel_passwords
+ ):
+ config["ssh_tunnel"]["password"] = db_ssh_tunnel_passwords[
+ config["uuid"]
+ ]
+
+ # populate ssh_tunnel_private_keys from the request or from existing DBs
+ if file_name in ssh_tunnel_private_keys:
+ config["ssh_tunnel"]["private_key"] = ssh_tunnel_private_keys[
+ file_name
+ ]
+ elif (
+ prefix == "databases"
+ and config["uuid"] in db_ssh_tunnel_private_keys
+ ):
+ config["ssh_tunnel"]["private_key"] = db_ssh_tunnel_private_keys[
+ config["uuid"]
+ ]
+
+ # populate ssh_tunnel_priv_key_passwords from the request or from existing DBs
+ if file_name in ssh_tunnel_priv_key_passwords:
+ config["ssh_tunnel"][
+ "private_key_password"
+ ] = ssh_tunnel_priv_key_passwords[file_name]
+ elif (
+ prefix == "databases"
+ and config["uuid"] in db_ssh_tunnel_priv_key_passws
+ ):
+ config["ssh_tunnel"][
+ "private_key_password"
+ ] = db_ssh_tunnel_priv_key_passws[config["uuid"]]
+
schema.load(config)
configs[file_name] = config
except ValidationError as exc:
diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py
index f930da4290..580fc8bc8a 100644
--- a/superset/dashboards/api.py
+++ b/superset/dashboards/api.py
@@ -1035,6 +1035,30 @@ class DashboardRestApi(BaseSupersetModelRestApi):
overwrite:
description: overwrite existing dashboards?
type: boolean
+ ssh_tunnel_passwords:
+ description: >-
+ JSON map of passwords for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_password"}`.
+ type: string
+ ssh_tunnel_private_keys:
+ description: >-
+ JSON map of private_keys for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key"}`.
+ type: string
+ ssh_tunnel_private_key_passwords:
+ description: >-
+ JSON map of private_key_passwords for each ssh_tunnel associated
+ to a featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key_password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key_password"}`.
+ type: string
responses:
200:
description: Dashboard import result
@@ -1074,8 +1098,29 @@ class DashboardRestApi(BaseSupersetModelRestApi):
)
overwrite = request.form.get("overwrite") == "true"
+ ssh_tunnel_passwords = (
+ json.loads(request.form["ssh_tunnel_passwords"])
+ if "ssh_tunnel_passwords" in request.form
+ else None
+ )
+ ssh_tunnel_private_keys = (
+ json.loads(request.form["ssh_tunnel_private_keys"])
+ if "ssh_tunnel_private_keys" in request.form
+ else None
+ )
+ ssh_tunnel_priv_key_passwords = (
+ json.loads(request.form["ssh_tunnel_private_key_passwords"])
+ if "ssh_tunnel_private_key_passwords" in request.form
+ else None
+ )
+
command = ImportDashboardsCommand(
- contents, passwords=passwords, overwrite=overwrite
+ contents,
+ passwords=passwords,
+ overwrite=overwrite,
+ ssh_tunnel_passwords=ssh_tunnel_passwords,
+ ssh_tunnel_private_keys=ssh_tunnel_private_keys,
+ ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run()
return self.response(200, message="OK")
diff --git a/superset/databases/api.py b/superset/databases/api.py
index a8de0b66d2..5bda161a3a 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -1095,6 +1095,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
overwrite:
description: overwrite existing databases?
type: boolean
+ ssh_tunnel_passwords:
+ description: >-
+ JSON map of passwords for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_password"}`.
+ type: string
+ ssh_tunnel_private_keys:
+ description: >-
+ JSON map of private_keys for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key"}`.
+ type: string
+ ssh_tunnel_private_key_passwords:
+ description: >-
+ JSON map of private_key_passwords for each ssh_tunnel associated
+ to a featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key_password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key_password"}`.
+ type: string
responses:
200:
description: Database import result
@@ -1131,9 +1155,29 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
else None
)
overwrite = request.form.get("overwrite") == "true"
+ ssh_tunnel_passwords = (
+ json.loads(request.form["ssh_tunnel_passwords"])
+ if "ssh_tunnel_passwords" in request.form
+ else None
+ )
+ ssh_tunnel_private_keys = (
+ json.loads(request.form["ssh_tunnel_private_keys"])
+ if "ssh_tunnel_private_keys" in request.form
+ else None
+ )
+ ssh_tunnel_priv_key_passwords = (
+ json.loads(request.form["ssh_tunnel_private_key_passwords"])
+ if "ssh_tunnel_private_key_passwords" in request.form
+ else None
+ )
command = ImportDatabasesCommand(
- contents, passwords=passwords, overwrite=overwrite
+ contents,
+ passwords=passwords,
+ overwrite=overwrite,
+ ssh_tunnel_passwords=ssh_tunnel_passwords,
+ ssh_tunnel_private_keys=ssh_tunnel_private_keys,
+ ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run()
return self.response(200, message="OK")
diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py
index 4d3bb7f99f..acb794531d 100644
--- a/superset/databases/commands/export.py
+++ b/superset/databases/commands/export.py
@@ -28,6 +28,7 @@ from superset.commands.export.models import ExportModelsCommand
from superset.models.core import Database
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
+from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
@@ -87,6 +88,15 @@ class ExportDatabasesCommand(ExportModelsCommand):
"schemas_allowed_for_file_upload"
)
+ if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.id):
+ ssh_tunnel_payload = ssh_tunnel.export_to_dict(
+ recursive=False,
+ include_parent_ref=False,
+ include_defaults=True,
+ export_uuids=False,
+ )
+ payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
+
payload["version"] = EXPORT_VERSION
file_content = yaml.safe_dump(payload, sort_keys=False)
diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py
index 6704ccd465..c3cc89e08c 100644
--- a/superset/databases/commands/importers/v1/utils.py
+++ b/superset/databases/commands/importers/v1/utils.py
@@ -20,6 +20,7 @@ from typing import Any, Dict
from sqlalchemy.orm import Session
+from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
@@ -42,8 +43,15 @@ def import_database(
# TODO (betodealmeida): move this logic to import_from_dict
config["extra"] = json.dumps(config["extra"])
+ # Before it gets removed in import_from_dict
+ ssh_tunnel = config.pop("ssh_tunnel", None)
+
database = Database.import_from_dict(session, config, recursive=False)
if database.id is None:
session.flush()
+ if ssh_tunnel:
+ ssh_tunnel["database_id"] = database.id
+ SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
+
return database
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index c4f37fa172..234209f173 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -19,7 +19,7 @@
import inspect
import json
-from typing import Any, Dict
+from typing import Any, Dict, List
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -28,9 +28,14 @@ from marshmallow.validate import Length, ValidationError
from marshmallow_enum import EnumField
from sqlalchemy import MetaData
-from superset import db
+from superset import db, is_feature_enabled
from superset.constants import PASSWORD_MASK
from superset.databases.commands.exceptions import DatabaseInvalidError
+from superset.databases.ssh_tunnel.commands.exceptions import (
+ SSHTunnelingNotEnabledError,
+ SSHTunnelInvalidCredentials,
+ SSHTunnelMissingCredentials,
+)
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
@@ -706,6 +711,7 @@ class ImportV1DatabaseSchema(Schema):
version = fields.String(required=True)
is_managed_externally = fields.Boolean(allow_none=True, default=False)
external_url = fields.String(allow_none=True)
+ ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
@validates_schema
def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
@@ -720,6 +726,68 @@ class ImportV1DatabaseSchema(Schema):
if password == PASSWORD_MASK and data.get("password") is None:
raise ValidationError("Must provide a password for the database")
+ @validates_schema
+ def validate_ssh_tunnel_credentials(
+ self, data: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """If the ssh_tunnel has masked credentials, the credentials are required"""
+ uuid = data["uuid"]
+ existing = db.session.query(Database).filter_by(uuid=uuid).first()
+ if existing:
+ return
+
+ # Our DB has an ssh_tunnel in it
+ if ssh_tunnel := data.get("ssh_tunnel"):
+ # Login methods are (only one from these options):
+ # 1. password
+ # 2. private_key + private_key_password
+ # Based on the data passed we determine what info is required.
+ # You cannot mix the credentials from both methods.
+ if not is_feature_enabled("SSH_TUNNELING"):
+ # You are trying to import a Database with SSH Tunnel
+ # But the Feature Flag is not enabled.
+ raise SSHTunnelingNotEnabledError()
+ password = ssh_tunnel.get("password")
+ private_key = ssh_tunnel.get("private_key")
+ private_key_password = ssh_tunnel.get("private_key_password")
+ if password is not None:
+ # Login method #1 (Password)
+ if private_key is not None or private_key_password is not None:
+ # You cannot have a mix of login methods
+ raise SSHTunnelInvalidCredentials()
+ if password == PASSWORD_MASK:
+ raise ValidationError("Must provide a password for the ssh tunnel")
+ if password is None:
+ # If the SSH Tunnel we're importing has no password then it must
+ # have a private_key + private_key_password combination
+ if private_key is None and private_key_password is None:
+ # We have found nothing related to other credentials
+ raise SSHTunnelMissingCredentials()
+ # We need to ask for the missing properties of our method # 2
+ # Some times the property is just missing
+ # or there're times where it's masked.
+ # If both are masked, we need to return a list of errors
+ # so the UI ask for both fields at the same time if needed
+ exception_messages: List[str] = []
+ if private_key is None or private_key == PASSWORD_MASK:
+ # If we get here we need to ask for the private key
+ exception_messages.append(
+ "Must provide a private key for the ssh tunnel"
+ )
+ if (
+ private_key_password is None
+ or private_key_password == PASSWORD_MASK
+ ):
+ # If we get here we need to ask for the private key password
+ exception_messages.append(
+ "Must provide a private key password for the ssh tunnel"
+ )
+ if exception_messages:
+ # We can ask for just one field or both if masked, if both
+ # are empty, SSHTunnelMissingCredentials was already raised
+ raise ValidationError(exception_messages)
+ return
+
class EncryptedField: # pylint: disable=too-few-public-methods
"""
diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py
index 2495961c36..0e3f91cae6 100644
--- a/superset/databases/ssh_tunnel/commands/exceptions.py
+++ b/superset/databases/ssh_tunnel/commands/exceptions.py
@@ -57,3 +57,11 @@ class SSHTunnelRequiredFieldValidationError(ValidationError):
[_("Field is required")],
field_name=field_name,
)
+
+
+class SSHTunnelMissingCredentials(CommandInvalidError):
+ message = _("Must provide credentials for the SSH Tunnel")
+
+
+class SSHTunnelInvalidCredentials(CommandInvalidError):
+ message = _("Cannot have multiple credentials for the SSH Tunnel")
diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py
index 727b6ea467..3384679cb7 100644
--- a/superset/databases/ssh_tunnel/models.py
+++ b/superset/databases/ssh_tunnel/models.py
@@ -68,6 +68,19 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
+ export_fields = [
+ "server_address",
+ "server_port",
+ "username",
+ "password",
+ "private_key",
+ "private_key_password",
+ ]
+
+ extra_import_fields = [
+ "database_id",
+ ]
+
@property
def data(self) -> Dict[str, Any]:
output = {
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index d58a1dd3f6..48c429d32d 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -830,6 +830,30 @@ class DatasetRestApi(BaseSupersetModelRestApi):
sync_metrics:
description: sync metrics?
type: boolean
+ ssh_tunnel_passwords:
+ description: >-
+ JSON map of passwords for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_password"}`.
+ type: string
+ ssh_tunnel_private_keys:
+ description: >-
+ JSON map of private_keys for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key"}`.
+ type: string
+ ssh_tunnel_private_key_passwords:
+ description: >-
+ JSON map of private_key_passwords for each ssh_tunnel associated
+ to a featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key_password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key_password"}`.
+ type: string
responses:
200:
description: Dataset import result
@@ -870,6 +894,21 @@ class DatasetRestApi(BaseSupersetModelRestApi):
overwrite = request.form.get("overwrite") == "true"
sync_columns = request.form.get("sync_columns") == "true"
sync_metrics = request.form.get("sync_metrics") == "true"
+ ssh_tunnel_passwords = (
+ json.loads(request.form["ssh_tunnel_passwords"])
+ if "ssh_tunnel_passwords" in request.form
+ else None
+ )
+ ssh_tunnel_private_keys = (
+ json.loads(request.form["ssh_tunnel_private_keys"])
+ if "ssh_tunnel_private_keys" in request.form
+ else None
+ )
+ ssh_tunnel_priv_key_passwords = (
+ json.loads(request.form["ssh_tunnel_private_key_passwords"])
+ if "ssh_tunnel_private_key_passwords" in request.form
+ else None
+ )
command = ImportDatasetsCommand(
contents,
@@ -877,6 +916,9 @@ class DatasetRestApi(BaseSupersetModelRestApi):
overwrite=overwrite,
sync_columns=sync_columns,
sync_metrics=sync_metrics,
+ ssh_tunnel_passwords=ssh_tunnel_passwords,
+ ssh_tunnel_private_keys=ssh_tunnel_private_keys,
+ ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run()
return self.response(200, message="OK")
diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py
index b71a95936a..cc6dad5d25 100644
--- a/superset/datasets/commands/export.py
+++ b/superset/datasets/commands/export.py
@@ -24,10 +24,12 @@ import yaml
from superset.commands.export.models import ExportModelsCommand
from superset.connectors.sqla.models import SqlaTable
+from superset.databases.dao import DatabaseDAO
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.dao import DatasetDAO
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
+from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
@@ -97,6 +99,15 @@ class ExportDatasetsCommand(ExportModelsCommand):
except json.decoder.JSONDecodeError:
logger.info("Unable to decode `extra` field: %s", payload["extra"])
+ if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.database.id):
+ ssh_tunnel_payload = ssh_tunnel.export_to_dict(
+ recursive=False,
+ include_parent_ref=False,
+ include_defaults=True,
+ export_uuids=False,
+ )
+ payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
+
payload["version"] = EXPORT_VERSION
file_content = yaml.safe_dump(payload, sort_keys=False)
diff --git a/superset/importexport/api.py b/superset/importexport/api.py
index 26bc78e5d7..5672f8e3a0 100644
--- a/superset/importexport/api.py
+++ b/superset/importexport/api.py
@@ -122,6 +122,30 @@ class ImportExportRestApi(BaseSupersetApi):
in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
+ ssh_tunnel_passwords:
+ description: >-
+ JSON map of passwords for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_password"}`.
+ type: string
+ ssh_tunnel_private_keys:
+ description: >-
+ JSON map of private_keys for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key"}`.
+ type: string
+ ssh_tunnel_private_key_passwords:
+ description: >-
+ JSON map of private_key_passwords for each ssh_tunnel associated
+ to a featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key_password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key_password"}`.
+ type: string
responses:
200:
description: Assets import result
@@ -158,7 +182,28 @@ class ImportExportRestApi(BaseSupersetApi):
if "passwords" in request.form
else None
)
+ ssh_tunnel_passwords = (
+ json.loads(request.form["ssh_tunnel_passwords"])
+ if "ssh_tunnel_passwords" in request.form
+ else None
+ )
+ ssh_tunnel_private_keys = (
+ json.loads(request.form["ssh_tunnel_private_keys"])
+ if "ssh_tunnel_private_keys" in request.form
+ else None
+ )
+ ssh_tunnel_priv_key_passwords = (
+ json.loads(request.form["ssh_tunnel_private_key_passwords"])
+ if "ssh_tunnel_private_key_passwords" in request.form
+ else None
+ )
- command = ImportAssetsCommand(contents, passwords=passwords)
+ command = ImportAssetsCommand(
+ contents,
+ passwords=passwords,
+ ssh_tunnel_passwords=ssh_tunnel_passwords,
+ ssh_tunnel_private_keys=ssh_tunnel_private_keys,
+ ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
+ )
command.run()
return self.response(200, message="OK")
diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py
index 69aabf3737..0996ab9b3f 100644
--- a/superset/queries/saved_queries/api.py
+++ b/superset/queries/saved_queries/api.py
@@ -324,6 +324,30 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
overwrite:
description: overwrite existing saved queries?
type: boolean
+ ssh_tunnel_passwords:
+ description: >-
+ JSON map of passwords for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_password"}`.
+ type: string
+ ssh_tunnel_private_keys:
+ description: >-
+ JSON map of private_keys for each ssh_tunnel associated to a
+ featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key"}`.
+ type: string
+ ssh_tunnel_private_key_passwords:
+ description: >-
+ JSON map of private_key_passwords for each ssh_tunnel associated
+ to a featured database in the ZIP file. If the ZIP includes a
+ ssh_tunnel config in the path `databases/MyDatabase.yaml`,
+ the private_key_password should be provided in the following format:
+ `{"databases/MyDatabase.yaml": "my_private_key_password"}`.
+ type: string
responses:
200:
description: Saved Query import result
@@ -360,9 +384,29 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
else None
)
overwrite = request.form.get("overwrite") == "true"
+ ssh_tunnel_passwords = (
+ json.loads(request.form["ssh_tunnel_passwords"])
+ if "ssh_tunnel_passwords" in request.form
+ else None
+ )
+ ssh_tunnel_private_keys = (
+ json.loads(request.form["ssh_tunnel_private_keys"])
+ if "ssh_tunnel_private_keys" in request.form
+ else None
+ )
+ ssh_tunnel_priv_key_passwords = (
+ json.loads(request.form["ssh_tunnel_private_key_passwords"])
+ if "ssh_tunnel_private_key_passwords" in request.form
+ else None
+ )
command = ImportSavedQueriesCommand(
- contents, passwords=passwords, overwrite=overwrite
+ contents,
+ passwords=passwords,
+ overwrite=overwrite,
+ ssh_tunnel_passwords=ssh_tunnel_passwords,
+ ssh_tunnel_private_keys=ssh_tunnel_private_keys,
+ ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run()
return self.response(200, message="OK")
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index b015e4c59b..3859c0be51 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -66,6 +66,11 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config,
database_metadata_config,
dataset_metadata_config,
+ database_with_ssh_tunnel_config_password,
+ database_with_ssh_tunnel_config_private_key,
+ database_with_ssh_tunnel_config_mix_credentials,
+ database_with_ssh_tunnel_config_no_credentials,
+ database_with_ssh_tunnel_config_private_pass_only,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_position,
@@ -2361,6 +2366,449 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(database)
db.session.commit()
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_password(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with masked ssh_tunnel password
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = database_with_ssh_tunnel_config_password.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ with bundle.open(
+ "database_export/datasets/imported_dataset.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(dataset_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 422
+ assert response == {
+ "errors": [
+ {
+ "message": "Error importing database",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "databases/imported_database.yaml": {
+ "_schema": ["Must provide a password for the ssh tunnel"]
+ },
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an "
+ "error while running a command."
+ ),
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_password_provided(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with masked ssh_tunnel password provided
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = database_with_ssh_tunnel_config_password.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ "ssh_tunnel_passwords": json.dumps(
+ {"databases/imported_database.yaml": "TEST"}
+ ),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 200
+ assert response == {"message": "OK"}
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert database.database_name == "imported_database"
+ model_ssh_tunnel = (
+ db.session.query(SSHTunnel)
+ .filter(SSHTunnel.database_id == database.id)
+ .one()
+ )
+ self.assertEqual(model_ssh_tunnel.password, "TEST")
+ db.session.delete(database)
+ db.session.commit()
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_private_key_and_password(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with masked ssh_tunnel private_key and password
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ with bundle.open(
+ "database_export/datasets/imported_dataset.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(dataset_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 422
+ assert response == {
+ "errors": [
+ {
+ "message": "Error importing database",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "databases/imported_database.yaml": {
+ "_schema": [
+ "Must provide a private key for the ssh tunnel",
+ "Must provide a private key password for the ssh tunnel",
+ ]
+ },
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an "
+ "error while running a command."
+ ),
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_private_key_and_password_provided(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with masked ssh_tunnel private_key and password provided
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ "ssh_tunnel_private_keys": json.dumps(
+ {"databases/imported_database.yaml": "TestPrivateKey"}
+ ),
+ "ssh_tunnel_private_key_passwords": json.dumps(
+ {"databases/imported_database.yaml": "TEST"}
+ ),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 200
+ assert response == {"message": "OK"}
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert database.database_name == "imported_database"
+ model_ssh_tunnel = (
+ db.session.query(SSHTunnel)
+ .filter(SSHTunnel.database_id == database.id)
+ .one()
+ )
+ self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
+ self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
+ db.session.delete(database)
+ db.session.commit()
+
+ def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self):
+ """
+ Database API: Test import database with ssh_tunnel and feature flag disabled
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+
+ masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ with bundle.open(
+ "database_export/datasets/imported_dataset.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(dataset_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 400
+ assert response == {
+ "errors": [
+ {
+ "message": "SSH Tunneling is not enabled",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an "
+ "error while running a command."
+ ),
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_feature_no_credentials(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with ssh_tunnel that has no credentials
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ with bundle.open(
+ "database_export/datasets/imported_dataset.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(dataset_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 422
+ assert response == {
+ "errors": [
+ {
+ "message": "Must provide credentials for the SSH Tunnel",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an "
+ "error while running a command."
+ ),
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_feature_mix_credentials(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with ssh_tunnel that has multiple credentials
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ with bundle.open(
+ "database_export/datasets/imported_dataset.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(dataset_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 422
+ assert response == {
+ "errors": [
+ {
+ "message": "Cannot have multiple credentials for the SSH Tunnel",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an "
+ "error while running a command."
+ ),
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd(
+ self, mock_schema_is_feature_enabled
+ ):
+ """
+ Database API: Test import database with ssh_tunnel that has only a private key password
+ """
+ self.login(username="admin")
+ uri = "api/v1/database/import/"
+ mock_schema_is_feature_enabled.return_value = True
+
+ masked_database_config = (
+ database_with_ssh_tunnel_config_private_pass_only.copy()
+ )
+
+ buf = BytesIO()
+ with ZipFile(buf, "w") as bundle:
+ with bundle.open("database_export/metadata.yaml", "w") as fp:
+ fp.write(yaml.safe_dump(database_metadata_config).encode())
+ with bundle.open(
+ "database_export/databases/imported_database.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(masked_database_config).encode())
+ with bundle.open(
+ "database_export/datasets/imported_dataset.yaml", "w"
+ ) as fp:
+ fp.write(yaml.safe_dump(dataset_config).encode())
+ buf.seek(0)
+
+ form_data = {
+ "formData": (buf, "database_export.zip"),
+ }
+ rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 422
+ assert response == {
+ "errors": [
+ {
+ "message": "Error importing database",
+ "error_type": "GENERIC_COMMAND_ERROR",
+ "level": "warning",
+ "extra": {
+ "databases/imported_database.yaml": {
+ "_schema": [
+ "Must provide a private key for the ssh tunnel",
+ "Must provide a private key password for the ssh tunnel",
+ ]
+ },
+ "issue_codes": [
+ {
+ "code": 1010,
+ "message": (
+ "Issue 1010 - Superset encountered an "
+ "error while running a command."
+ ),
+ }
+ ],
+ },
+ }
+ ]
+ }
+
@mock.patch(
"superset.db_engine_specs.base.BaseEngineSpec.get_function_names",
)
diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py
index 7e4fcaad78..6a5a9c58bc 100644
--- a/tests/integration_tests/databases/commands_tests.py
+++ b/tests/integration_tests/databases/commands_tests.py
@@ -41,6 +41,7 @@ from superset.databases.commands.tables import TablesDatabaseCommand
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.validate import ValidateDatabaseParametersCommand
from superset.databases.schemas import DatabaseTestConnectionSchema
+from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetErrorsException,
@@ -63,6 +64,11 @@ from tests.integration_tests.fixtures.energy_dashboard import (
from tests.integration_tests.fixtures.importexport import (
database_config,
database_metadata_config,
+ database_with_ssh_tunnel_config_mix_credentials,
+ database_with_ssh_tunnel_config_no_credentials,
+ database_with_ssh_tunnel_config_password,
+ database_with_ssh_tunnel_config_private_key,
+ database_with_ssh_tunnel_config_private_pass_only,
dataset_config,
dataset_metadata_config,
)
@@ -623,6 +629,191 @@ class TestImportDatabasesCommand(SupersetTestCase):
}
}
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_masked_ssh_tunnel_password(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that database imports with masked ssh_tunnel passwords are rejected"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = database_with_ssh_tunnel_config_password.copy()
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Error importing database"
+ assert excinfo.value.normalized_messages() == {
+ "databases/imported_database.yaml": {
+ "_schema": ["Must provide a password for the ssh tunnel"]
+ }
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_masked_ssh_tunnel_private_key_and_password(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that database imports with masked ssh_tunnel private_key and private_key_password are rejected"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Error importing database"
+ assert excinfo.value.normalized_messages() == {
+ "databases/imported_database.yaml": {
+ "_schema": [
+ "Must provide a private key for the ssh tunnel",
+ "Must provide a private key password for the ssh tunnel",
+ ]
+ }
+ }
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_with_ssh_tunnel_password(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that a database with ssh_tunnel password can be imported"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = database_with_ssh_tunnel_config_password.copy()
+ masked_database_config["ssh_tunnel"]["password"] = "TEST"
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ command.run()
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert database.allow_file_upload
+ assert database.allow_ctas
+ assert database.allow_cvas
+ assert database.allow_dml
+ assert not database.allow_run_async
+ assert database.cache_timeout is None
+ assert database.database_name == "imported_database"
+ assert database.expose_in_sqllab
+ assert database.extra == "{}"
+ assert database.sqlalchemy_uri == "sqlite:///test.db"
+
+ model_ssh_tunnel = (
+ db.session.query(SSHTunnel)
+ .filter(SSHTunnel.database_id == database.id)
+ .one()
+ )
+ self.assertEqual(model_ssh_tunnel.password, "TEST")
+
+ db.session.delete(database)
+ db.session.commit()
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_with_ssh_tunnel_private_key_and_password(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that a database with ssh_tunnel private_key and private_key_password can be imported"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
+ masked_database_config["ssh_tunnel"]["private_key"] = "TestPrivateKey"
+ masked_database_config["ssh_tunnel"]["private_key_password"] = "TEST"
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ command.run()
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert database.allow_file_upload
+ assert database.allow_ctas
+ assert database.allow_cvas
+ assert database.allow_dml
+ assert not database.allow_run_async
+ assert database.cache_timeout is None
+ assert database.database_name == "imported_database"
+ assert database.expose_in_sqllab
+ assert database.extra == "{}"
+ assert database.sqlalchemy_uri == "sqlite:///test.db"
+
+ model_ssh_tunnel = (
+ db.session.query(SSHTunnel)
+ .filter(SSHTunnel.database_id == database.id)
+ .one()
+ )
+ self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
+ self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
+
+ db.session.delete(database)
+ db.session.commit()
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_masked_ssh_tunnel_no_credentials(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that databases with ssh_tunnels that have no credentials are rejected"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Must provide credentials for the SSH Tunnel"
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_masked_ssh_tunnel_multiple_credentials(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that databases with ssh_tunnels that have multiple credentials are rejected"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert (
+ str(excinfo.value) == "Cannot have multiple credentials for the SSH Tunnel"
+ )
+
+ @mock.patch("superset.databases.schemas.is_feature_enabled")
+ def test_import_v1_database_masked_ssh_tunnel_only_priv_key_psswd(
+ self, mock_schema_is_feature_enabled
+ ):
+ """Test that databases with ssh_tunnels that have only a private_key_password are rejected"""
+ mock_schema_is_feature_enabled.return_value = True
+ masked_database_config = (
+ database_with_ssh_tunnel_config_private_pass_only.copy()
+ )
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Error importing database"
+ assert excinfo.value.normalized_messages() == {
+ "databases/imported_database.yaml": {
+ "_schema": [
+ "Must provide a private key for the ssh tunnel",
+ "Must provide a private key password for the ssh tunnel",
+ ]
+ }
+ }
+
@patch("superset.databases.commands.importers.v1.import_dataset")
def test_import_v1_rollback(self, mock_import_dataset):
"""Test than on an exception everything is rolled back"""
diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py
index b624f3e63c..d5c898eba2 100644
--- a/tests/integration_tests/fixtures/importexport.py
+++ b/tests/integration_tests/fixtures/importexport.py
@@ -361,6 +361,113 @@ database_config: Dict[str, Any] = {
"version": "1.0.0",
}
+database_with_ssh_tunnel_config_private_key: Dict[str, Any] = {
+ "allow_csv_upload": True,
+ "allow_ctas": True,
+ "allow_cvas": True,
+ "allow_dml": True,
+ "allow_run_async": False,
+ "cache_timeout": None,
+ "database_name": "imported_database",
+ "expose_in_sqllab": True,
+ "extra": {},
+ "sqlalchemy_uri": "sqlite:///test.db",
+ "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
+ "ssh_tunnel": {
+ "server_address": "localhost",
+ "server_port": 22,
+ "username": "Test",
+ "private_key": "XXXXXXXXXX",
+ "private_key_password": "XXXXXXXXXX",
+ },
+ "version": "1.0.0",
+}
+
+database_with_ssh_tunnel_config_password: Dict[str, Any] = {
+ "allow_csv_upload": True,
+ "allow_ctas": True,
+ "allow_cvas": True,
+ "allow_dml": True,
+ "allow_run_async": False,
+ "cache_timeout": None,
+ "database_name": "imported_database",
+ "expose_in_sqllab": True,
+ "extra": {},
+ "sqlalchemy_uri": "sqlite:///test.db",
+ "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
+ "ssh_tunnel": {
+ "server_address": "localhost",
+ "server_port": 22,
+ "username": "Test",
+ "password": "XXXXXXXXXX",
+ },
+ "version": "1.0.0",
+}
+
+database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = {
+ "allow_csv_upload": True,
+ "allow_ctas": True,
+ "allow_cvas": True,
+ "allow_dml": True,
+ "allow_run_async": False,
+ "cache_timeout": None,
+ "database_name": "imported_database",
+ "expose_in_sqllab": True,
+ "extra": {},
+ "sqlalchemy_uri": "sqlite:///test.db",
+ "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
+ "ssh_tunnel": {
+ "server_address": "localhost",
+ "server_port": 22,
+ "username": "Test",
+ },
+ "version": "1.0.0",
+}
+
+database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = {
+ "allow_csv_upload": True,
+ "allow_ctas": True,
+ "allow_cvas": True,
+ "allow_dml": True,
+ "allow_run_async": False,
+ "cache_timeout": None,
+ "database_name": "imported_database",
+ "expose_in_sqllab": True,
+ "extra": {},
+ "sqlalchemy_uri": "sqlite:///test.db",
+ "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
+ "ssh_tunnel": {
+ "server_address": "localhost",
+ "server_port": 22,
+ "username": "Test",
+ "password": "XXXXXXXXXX",
+ "private_key": "XXXXXXXXXX",
+ },
+ "version": "1.0.0",
+}
+
+database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = {
+ "allow_csv_upload": True,
+ "allow_ctas": True,
+ "allow_cvas": True,
+ "allow_dml": True,
+ "allow_run_async": False,
+ "cache_timeout": None,
+ "database_name": "imported_database",
+ "expose_in_sqllab": True,
+ "extra": {},
+ "sqlalchemy_uri": "sqlite:///test.db",
+ "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
+ "ssh_tunnel": {
+ "server_address": "localhost",
+ "server_port": 22,
+ "username": "Test",
+ "private_key_password": "XXXXXXXXXX",
+ },
+ "version": "1.0.0",
+}
+
+
dataset_config: Dict[str, Any] = {
"table_name": "imported_dataset",
"main_dttm_col": None,
diff --git a/tests/unit_tests/importexport/api_test.py b/tests/unit_tests/importexport/api_test.py
index a65a682018..86fdd72308 100644
--- a/tests/unit_tests/importexport/api_test.py
+++ b/tests/unit_tests/importexport/api_test.py
@@ -99,7 +99,13 @@ def test_import_assets(
assert response.json == {"message": "OK"}
passwords = {"assets_export/databases/imported_database.yaml": "SECRET"}
- ImportAssetsCommand.assert_called_with(mocked_contents, passwords=passwords)
+ ImportAssetsCommand.assert_called_with(
+ mocked_contents,
+ passwords=passwords,
+ ssh_tunnel_passwords=None,
+ ssh_tunnel_private_keys=None,
+ ssh_tunnel_priv_key_passwords=None,
+ )
def test_import_assets_not_zip(