feat(ssh_tunnel): Import/Export Databases with SSHTunnel credentials (#23099)

This commit is contained in:
Antonio Rivero 2023-02-24 14:36:21 -03:00 committed by GitHub
parent 967383853c
commit 3484e8ea7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 2039 additions and 50 deletions

View File

@ -10716,6 +10716,18 @@
"passwords": { "passwords": {
"description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.", "description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string" "type": "string"
},
"ssh_tunnel_passwords": {
"description": "JSON map of passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string"
},
"ssh_tunnel_private_keys": {
"description": "JSON map of private_keys for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key\"}`.",
"type": "string"
},
"ssh_tunnel_private_keyspasswords": {
"description": "JSON map of private_key_passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key_password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key_password\"}`.",
"type": "string"
} }
}, },
"type": "object" "type": "object"
@ -11439,6 +11451,18 @@
"passwords": { "passwords": {
"description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.", "description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string" "type": "string"
},
"ssh_tunnel_passwords": {
"description": "JSON map of passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string"
},
"ssh_tunnel_private_keys": {
"description": "JSON map of private_keys for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key\"}`.",
"type": "string"
},
"ssh_tunnel_private_keyspasswords": {
"description": "JSON map of private_key_passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key_password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key_password\"}`.",
"type": "string"
} }
}, },
"type": "object" "type": "object"
@ -13020,6 +13044,18 @@
"passwords": { "passwords": {
"description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.", "description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string" "type": "string"
},
"ssh_tunnel_passwords": {
"description": "JSON map of passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string"
},
"ssh_tunnel_private_keys": {
"description": "JSON map of private_keys for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key\"}`.",
"type": "string"
},
"ssh_tunnel_private_keyspasswords": {
"description": "JSON map of private_key_passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key_password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key_password\"}`.",
"type": "string"
} }
}, },
"type": "object" "type": "object"
@ -14788,6 +14824,18 @@
"passwords": { "passwords": {
"description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.", "description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string" "type": "string"
},
"ssh_tunnel_passwords": {
"description": "JSON map of passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string"
},
"ssh_tunnel_private_keys": {
"description": "JSON map of private_keys for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key\"}`.",
"type": "string"
},
"ssh_tunnel_private_keyspasswords": {
"description": "JSON map of private_key_passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key_password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key_password\"}`.",
"type": "string"
} }
}, },
"type": "object" "type": "object"
@ -16231,6 +16279,18 @@
"sync_metrics": { "sync_metrics": {
"description": "sync metrics?", "description": "sync metrics?",
"type": "boolean" "type": "boolean"
},
"ssh_tunnel_passwords": {
"description": "JSON map of passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string"
},
"ssh_tunnel_private_keys": {
"description": "JSON map of private_keys for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key\"}`.",
"type": "string"
},
"ssh_tunnel_private_keyspasswords": {
"description": "JSON map of private_key_passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key_password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key_password\"}`.",
"type": "string"
} }
}, },
"type": "object" "type": "object"
@ -19428,6 +19488,18 @@
"passwords": { "passwords": {
"description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.", "description": "JSON map of passwords for each featured database in the ZIP file. If the ZIP includes a database config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string" "type": "string"
},
"ssh_tunnel_passwords": {
"description": "JSON map of passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_password\"}`.",
"type": "string"
},
"ssh_tunnel_private_keys": {
"description": "JSON map of private_keys for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key\"}`.",
"type": "string"
},
"ssh_tunnel_private_keyspasswords": {
"description": "JSON map of private_key_passwords for each ssh_tunnel associated to a featured database in the ZIP file. If the ZIP includes a ssh_tunnel config in the path `databases/MyDatabase.yaml`, the private_key_password should be provided in the following format: `{\"databases/MyDatabase.yaml\": \"my_private_key_password\"}`.",
"type": "string"
} }
}, },
"type": "object" "type": "object"

View File

@ -146,4 +146,51 @@ describe('ImportModelsModal', () => {
); );
expect(wrapperWithPasswords.find('input[type="password"]')).toExist(); expect(wrapperWithPasswords.find('input[type="password"]')).toExist();
}); });
it('should render ssh_tunnel password fields when needed for import', () => {
const wrapperWithPasswords = mount(
<ImportModelsModal
{...requiredProps}
sshTunnelPasswordFields={['databases/examples.yaml']}
/>,
{
context: { store },
},
);
expect(
wrapperWithPasswords.find('[data-test="ssh_tunnel_password"]'),
).toExist();
});
it('should render ssh_tunnel private_key fields when needed for import', () => {
const wrapperWithPasswords = mount(
<ImportModelsModal
{...requiredProps}
sshTunnelPrivateKeyFields={['databases/examples.yaml']}
/>,
{
context: { store },
},
);
expect(
wrapperWithPasswords.find('[data-test="ssh_tunnel_private_key"]'),
).toExist();
});
it('should render ssh_tunnel private_key_password fields when needed for import', () => {
const wrapperWithPasswords = mount(
<ImportModelsModal
{...requiredProps}
sshTunnelPrivateKeyPasswordFields={['databases/examples.yaml']}
/>,
{
context: { store },
},
);
expect(
wrapperWithPasswords.find(
'[data-test="ssh_tunnel_private_key_password"]',
),
).toExist();
});
}); });

View File

@ -110,6 +110,14 @@ export interface ImportModelsModalProps {
onHide: () => void; onHide: () => void;
passwordFields?: string[]; passwordFields?: string[];
setPasswordFields?: (passwordFields: string[]) => void; setPasswordFields?: (passwordFields: string[]) => void;
sshTunnelPasswordFields?: string[];
setSSHTunnelPasswordFields?: (sshTunnelPasswordFields: string[]) => void;
sshTunnelPrivateKeyFields?: string[];
setSSHTunnelPrivateKeyFields?: (sshTunnelPrivateKeyFields: string[]) => void;
sshTunnelPrivateKeyPasswordFields?: string[];
setSSHTunnelPrivateKeyPasswordFields?: (
sshTunnelPrivateKeyPasswordFields: string[],
) => void;
} }
const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
@ -122,6 +130,12 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
onHide, onHide,
passwordFields = [], passwordFields = [],
setPasswordFields = () => {}, setPasswordFields = () => {},
sshTunnelPasswordFields = [],
setSSHTunnelPasswordFields = () => {},
sshTunnelPrivateKeyFields = [],
setSSHTunnelPrivateKeyFields = () => {},
sshTunnelPrivateKeyPasswordFields = [],
setSSHTunnelPrivateKeyPasswordFields = () => {},
}) => { }) => {
const [isHidden, setIsHidden] = useState<boolean>(true); const [isHidden, setIsHidden] = useState<boolean>(true);
const [passwords, setPasswords] = useState<Record<string, string>>({}); const [passwords, setPasswords] = useState<Record<string, string>>({});
@ -131,6 +145,14 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
const [fileList, setFileList] = useState<UploadFile[]>([]); const [fileList, setFileList] = useState<UploadFile[]>([]);
const [importingModel, setImportingModel] = useState<boolean>(false); const [importingModel, setImportingModel] = useState<boolean>(false);
const [errorMessage, setErrorMessage] = useState<string>(); const [errorMessage, setErrorMessage] = useState<string>();
const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
Record<string, string>
>({});
const [sshTunnelPrivateKeys, setSSHTunnelPrivateKeys] = useState<
Record<string, string>
>({});
const [sshTunnelPrivateKeyPasswords, setSSHTunnelPrivateKeyPasswords] =
useState<Record<string, string>>({});
const clearModal = () => { const clearModal = () => {
setFileList([]); setFileList([]);
@ -140,6 +162,12 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
setConfirmedOverwrite(false); setConfirmedOverwrite(false);
setImportingModel(false); setImportingModel(false);
setErrorMessage(''); setErrorMessage('');
setSSHTunnelPasswordFields([]);
setSSHTunnelPrivateKeyFields([]);
setSSHTunnelPrivateKeyPasswordFields([]);
setSSHTunnelPasswords({});
setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({});
}; };
const handleErrorMsg = (msg: string) => { const handleErrorMsg = (msg: string) => {
@ -147,7 +175,13 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
}; };
const { const {
state: { alreadyExists, passwordsNeeded }, state: {
alreadyExists,
passwordsNeeded,
sshPasswordNeeded,
sshPrivateKeyNeeded,
sshPrivateKeyPasswordNeeded,
},
importResource, importResource,
} = useImportResource(resourceName, resourceLabel, handleErrorMsg); } = useImportResource(resourceName, resourceLabel, handleErrorMsg);
@ -165,6 +199,27 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
} }
}, [alreadyExists, setNeedsOverwriteConfirm]); }, [alreadyExists, setNeedsOverwriteConfirm]);
useEffect(() => {
setSSHTunnelPasswordFields(sshPasswordNeeded);
if (sshPasswordNeeded.length > 0) {
setImportingModel(false);
}
}, [sshPasswordNeeded, setSSHTunnelPasswordFields]);
useEffect(() => {
setSSHTunnelPrivateKeyFields(sshPrivateKeyNeeded);
if (sshPrivateKeyNeeded.length > 0) {
setImportingModel(false);
}
}, [sshPrivateKeyNeeded, setSSHTunnelPrivateKeyFields]);
useEffect(() => {
setSSHTunnelPrivateKeyPasswordFields(sshPrivateKeyPasswordNeeded);
if (sshPrivateKeyPasswordNeeded.length > 0) {
setImportingModel(false);
}
}, [sshPrivateKeyPasswordNeeded, setSSHTunnelPrivateKeyPasswordFields]);
// Functions // Functions
const hide = () => { const hide = () => {
setIsHidden(true); setIsHidden(true);
@ -181,6 +236,9 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
importResource( importResource(
fileList[0].originFileObj, fileList[0].originFileObj,
passwords, passwords,
sshTunnelPasswords,
sshTunnelPrivateKeys,
sshTunnelPrivateKeyPasswords,
confirmedOverwrite, confirmedOverwrite,
).then(result => { ).then(result => {
if (result) { if (result) {
@ -210,30 +268,117 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
}; };
const renderPasswordFields = () => { const renderPasswordFields = () => {
if (passwordFields.length === 0) { if (
passwordFields.length === 0 &&
sshTunnelPasswordFields.length === 0 &&
sshTunnelPrivateKeyFields.length === 0 &&
sshTunnelPrivateKeyPasswordFields.length === 0
) {
return null; return null;
} }
const files = [
...new Set([
...passwordFields,
...sshTunnelPasswordFields,
...sshTunnelPrivateKeyFields,
...sshTunnelPrivateKeyPasswordFields,
]),
];
return ( return (
<> <>
<h5>{t('Database passwords')}</h5> <h5>{t('Database passwords')}</h5>
<HelperMessage>{passwordsNeededMessage}</HelperMessage> <HelperMessage>{passwordsNeededMessage}</HelperMessage>
{passwordFields.map(fileName => ( {files.map(fileName => (
<StyledInputContainer key={`password-for-${fileName}`}> <>
<div className="control-label"> {passwordFields?.indexOf(fileName) >= 0 && (
{fileName} <StyledInputContainer key={`password-for-${fileName}`}>
<span className="required">*</span> <div className="control-label">
</div> {t('%s PASSWORD', fileName.slice(10))}
<input <span className="required">*</span>
name={`password-${fileName}`} </div>
autoComplete={`password-${fileName}`} <input
type="password" name={`password-${fileName}`}
value={passwords[fileName]} autoComplete={`password-${fileName}`}
onChange={event => type="password"
setPasswords({ ...passwords, [fileName]: event.target.value }) value={passwords[fileName]}
} onChange={event =>
/> setPasswords({
</StyledInputContainer> ...passwords,
[fileName]: event.target.value,
})
}
/>
</StyledInputContainer>
)}
{sshTunnelPasswordFields?.indexOf(fileName) >= 0 && (
<StyledInputContainer key={`ssh_tunnel_password-for-${fileName}`}>
<div className="control-label">
{t('%s SSH TUNNEL PASSWORD', fileName.slice(10))}
<span className="required">*</span>
</div>
<input
name={`ssh_tunnel_password-${fileName}`}
autoComplete={`ssh_tunnel_password-${fileName}`}
type="password"
value={sshTunnelPasswords[fileName]}
onChange={event =>
setSSHTunnelPasswords({
...sshTunnelPasswords,
[fileName]: event.target.value,
})
}
data-test="ssh_tunnel_password"
/>
</StyledInputContainer>
)}
{sshTunnelPrivateKeyFields?.indexOf(fileName) >= 0 && (
<StyledInputContainer
key={`ssh_tunnel_private_key-for-${fileName}`}
>
<div className="control-label">
{t('%s SSH TUNNEL PRIVATE KEY', fileName.slice(10))}
<span className="required">*</span>
</div>
<textarea
name={`ssh_tunnel_private_key-${fileName}`}
autoComplete={`ssh_tunnel_private_key-${fileName}`}
value={sshTunnelPrivateKeys[fileName]}
onChange={event =>
setSSHTunnelPrivateKeys({
...sshTunnelPrivateKeys,
[fileName]: event.target.value,
})
}
data-test="ssh_tunnel_private_key"
/>
</StyledInputContainer>
)}
{sshTunnelPrivateKeyPasswordFields?.indexOf(fileName) >= 0 && (
<StyledInputContainer
key={`ssh_tunnel_private_key_password-for-${fileName}`}
>
<div className="control-label">
{t('%s SSH TUNNEL PRIVATE KEY PASSWORD', fileName.slice(10))}
<span className="required">*</span>
</div>
<input
name={`ssh_tunnel_private_key_password-${fileName}`}
autoComplete={`ssh_tunnel_private_key_password-${fileName}`}
type="password"
value={sshTunnelPrivateKeyPasswords[fileName]}
onChange={event =>
setSSHTunnelPrivateKeyPasswords({
...sshTunnelPrivateKeyPasswords,
[fileName]: event.target.value,
})
}
data-test="ssh_tunnel_private_key_password"
/>
</StyledInputContainer>
)}
</>
))} ))}
</> </>
); );
@ -303,7 +448,12 @@ const ImportModelsModal: FunctionComponent<ImportModelsModalProps> = ({
{errorMessage && ( {errorMessage && (
<ErrorAlert <ErrorAlert
errorMessage={errorMessage} errorMessage={errorMessage}
showDbInstallInstructions={passwordFields.length > 0} showDbInstallInstructions={
passwordFields.length > 0 ||
sshTunnelPasswordFields.length > 0 ||
sshTunnelPrivateKeyFields.length > 0 ||
sshTunnelPrivateKeyPasswordFields.length > 0
}
/> />
)} )}
{renderPasswordFields()} {renderPasswordFields()}

View File

@ -197,6 +197,16 @@ function ChartList(props: ChartListProps) {
const [importingChart, showImportModal] = useState<boolean>(false); const [importingChart, showImportModal] = useState<boolean>(false);
const [passwordFields, setPasswordFields] = useState<string[]>([]); const [passwordFields, setPasswordFields] = useState<string[]>([]);
const [preparingExport, setPreparingExport] = useState<boolean>(false); const [preparingExport, setPreparingExport] = useState<boolean>(false);
const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
string[]
>([]);
const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
string[]
>([]);
const [
sshTunnelPrivateKeyPasswordFields,
setSSHTunnelPrivateKeyPasswordFields,
] = useState<string[]>([]);
// TODO: Fix usage of localStorage keying on the user id // TODO: Fix usage of localStorage keying on the user id
const userSettings = dangerouslyGetItemDoNotUse(userId?.toString(), null) as { const userSettings = dangerouslyGetItemDoNotUse(userId?.toString(), null) as {
@ -888,6 +898,14 @@ function ChartList(props: ChartListProps) {
onHide={closeChartImportModal} onHide={closeChartImportModal}
passwordFields={passwordFields} passwordFields={passwordFields}
setPasswordFields={setPasswordFields} setPasswordFields={setPasswordFields}
sshTunnelPasswordFields={sshTunnelPasswordFields}
setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
setSSHTunnelPrivateKeyPasswordFields={
setSSHTunnelPrivateKeyPasswordFields
}
/> />
{preparingExport && <Loading />} {preparingExport && <Loading />}
</> </>

View File

@ -145,6 +145,16 @@ function DashboardList(props: DashboardListProps) {
const [preparingExport, setPreparingExport] = useState<boolean>(false); const [preparingExport, setPreparingExport] = useState<boolean>(false);
const enableBroadUserAccess = const enableBroadUserAccess =
bootstrapData?.common?.conf?.ENABLE_BROAD_ACTIVITY_ACCESS; bootstrapData?.common?.conf?.ENABLE_BROAD_ACTIVITY_ACCESS;
const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
string[]
>([]);
const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
string[]
>([]);
const [
sshTunnelPrivateKeyPasswordFields,
setSSHTunnelPrivateKeyPasswordFields,
] = useState<string[]>([]);
const openDashboardImportModal = () => { const openDashboardImportModal = () => {
showImportModal(true); showImportModal(true);
@ -789,6 +799,14 @@ function DashboardList(props: DashboardListProps) {
onHide={closeDashboardImportModal} onHide={closeDashboardImportModal}
passwordFields={passwordFields} passwordFields={passwordFields}
setPasswordFields={setPasswordFields} setPasswordFields={setPasswordFields}
sshTunnelPasswordFields={sshTunnelPasswordFields}
setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
setSSHTunnelPrivateKeyPasswordFields={
setSSHTunnelPrivateKeyPasswordFields
}
/> />
{preparingExport && <Loading />} {preparingExport && <Loading />}

View File

@ -555,11 +555,29 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const [isLoading, setLoading] = useState<boolean>(false); const [isLoading, setLoading] = useState<boolean>(false);
const [testInProgress, setTestInProgress] = useState<boolean>(false); const [testInProgress, setTestInProgress] = useState<boolean>(false);
const [passwords, setPasswords] = useState<Record<string, string>>({}); const [passwords, setPasswords] = useState<Record<string, string>>({});
const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
Record<string, string>
>({});
const [sshTunnelPrivateKeys, setSSHTunnelPrivateKeys] = useState<
Record<string, string>
>({});
const [sshTunnelPrivateKeyPasswords, setSSHTunnelPrivateKeyPasswords] =
useState<Record<string, string>>({});
const [confirmedOverwrite, setConfirmedOverwrite] = useState<boolean>(false); const [confirmedOverwrite, setConfirmedOverwrite] = useState<boolean>(false);
const [fileList, setFileList] = useState<UploadFile[]>([]); const [fileList, setFileList] = useState<UploadFile[]>([]);
const [importingModal, setImportingModal] = useState<boolean>(false); const [importingModal, setImportingModal] = useState<boolean>(false);
const [importingErrorMessage, setImportingErrorMessage] = useState<string>(); const [importingErrorMessage, setImportingErrorMessage] = useState<string>();
const [passwordFields, setPasswordFields] = useState<string[]>([]); const [passwordFields, setPasswordFields] = useState<string[]>([]);
const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
string[]
>([]);
const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
string[]
>([]);
const [
sshTunnelPrivateKeyPasswordFields,
setSSHTunnelPrivateKeyPasswordFields,
] = useState<string[]>([]);
const SSHTunnelSwitchComponent = const SSHTunnelSwitchComponent =
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch; extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
@ -657,7 +675,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setImportingModal(false); setImportingModal(false);
setImportingErrorMessage(''); setImportingErrorMessage('');
setPasswordFields([]); setPasswordFields([]);
setSSHTunnelPasswordFields([]);
setSSHTunnelPrivateKeyFields([]);
setSSHTunnelPrivateKeyPasswordFields([]);
setPasswords({}); setPasswords({});
setSSHTunnelPasswords({});
setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({});
setConfirmedOverwrite(false); setConfirmedOverwrite(false);
setUseSSHTunneling(false); setUseSSHTunneling(false);
onHide(); onHide();
@ -678,6 +702,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
state: { state: {
alreadyExists, alreadyExists,
passwordsNeeded, passwordsNeeded,
sshPasswordNeeded,
sshPrivateKeyNeeded,
sshPrivateKeyPasswordNeeded,
loading: importLoading, loading: importLoading,
failed: importErrored, failed: importErrored,
}, },
@ -811,6 +838,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const dbId = await importResource( const dbId = await importResource(
fileList[0].originFileObj, fileList[0].originFileObj,
passwords, passwords,
sshTunnelPasswords,
sshTunnelPrivateKeys,
sshTunnelPrivateKeyPasswords,
confirmedOverwrite, confirmedOverwrite,
); );
if (dbId) { if (dbId) {
@ -983,7 +1013,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setImportingModal(false); setImportingModal(false);
setImportingErrorMessage(''); setImportingErrorMessage('');
setPasswordFields([]); setPasswordFields([]);
setSSHTunnelPasswordFields([]);
setSSHTunnelPrivateKeyFields([]);
setSSHTunnelPrivateKeyPasswordFields([]);
setPasswords({}); setPasswords({});
setSSHTunnelPasswords({});
setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({});
} }
setDB({ type: ActionType.reset }); setDB({ type: ActionType.reset });
setFileList([]); setFileList([]);
@ -993,7 +1029,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
if ( if (
importLoading || importLoading ||
(alreadyExists.length && !confirmedOverwrite) || (alreadyExists.length && !confirmedOverwrite) ||
(passwordsNeeded.length && JSON.stringify(passwords) === '{}') (passwordsNeeded.length && JSON.stringify(passwords) === '{}') ||
(sshPasswordNeeded.length &&
JSON.stringify(sshTunnelPasswords) === '{}') ||
(sshPrivateKeyNeeded.length &&
JSON.stringify(sshTunnelPrivateKeys) === '{}') ||
(sshPrivateKeyPasswordNeeded.length &&
JSON.stringify(sshTunnelPrivateKeyPasswords) === '{}')
) )
return true; return true;
return false; return false;
@ -1098,13 +1140,24 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
!importLoading && !importLoading &&
!alreadyExists.length && !alreadyExists.length &&
!passwordsNeeded.length && !passwordsNeeded.length &&
!sshPasswordNeeded.length &&
!sshPrivateKeyNeeded.length &&
!sshPrivateKeyPasswordNeeded.length &&
!isLoading && // This prevents a double toast for non-related imports !isLoading && // This prevents a double toast for non-related imports
!importErrored // This prevents a success toast on error !importErrored // This prevents a success toast on error
) { ) {
onClose(); onClose();
addSuccessToast(t('Database connected')); addSuccessToast(t('Database connected'));
} }
}, [alreadyExists, passwordsNeeded, importLoading, importErrored]); }, [
alreadyExists,
passwordsNeeded,
importLoading,
importErrored,
sshPasswordNeeded,
sshPrivateKeyNeeded,
sshPrivateKeyPasswordNeeded,
]);
useEffect(() => { useEffect(() => {
if (show) { if (show) {
@ -1153,6 +1206,18 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setPasswordFields([...passwordsNeeded]); setPasswordFields([...passwordsNeeded]);
}, [passwordsNeeded]); }, [passwordsNeeded]);
useEffect(() => {
setSSHTunnelPasswordFields([...sshPasswordNeeded]);
}, [sshPasswordNeeded]);
useEffect(() => {
setSSHTunnelPrivateKeyFields([...sshPrivateKeyNeeded]);
}, [sshPrivateKeyNeeded]);
useEffect(() => {
setSSHTunnelPrivateKeyPasswordFields([...sshPrivateKeyPasswordNeeded]);
}, [sshPrivateKeyPasswordNeeded]);
useEffect(() => { useEffect(() => {
if (db && isSSHTunneling) { if (db && isSSHTunneling) {
setUseSSHTunneling(!isEmpty(db?.ssh_tunnel)); setUseSSHTunneling(!isEmpty(db?.ssh_tunnel));
@ -1162,7 +1227,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const onDbImport = async (info: UploadChangeParam) => { const onDbImport = async (info: UploadChangeParam) => {
setImportingErrorMessage(''); setImportingErrorMessage('');
setPasswordFields([]); setPasswordFields([]);
setSSHTunnelPasswordFields([]);
setSSHTunnelPrivateKeyFields([]);
setSSHTunnelPrivateKeyPasswordFields([]);
setPasswords({}); setPasswords({});
setSSHTunnelPasswords({});
setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({});
setImportingModal(true); setImportingModal(true);
setFileList([ setFileList([
{ {
@ -1175,15 +1246,33 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const dbId = await importResource( const dbId = await importResource(
info.file.originFileObj, info.file.originFileObj,
passwords, passwords,
sshTunnelPasswords,
sshTunnelPrivateKeys,
sshTunnelPrivateKeyPasswords,
confirmedOverwrite, confirmedOverwrite,
); );
if (dbId) onDatabaseAdd?.(); if (dbId) onDatabaseAdd?.();
}; };
const passwordNeededField = () => { const passwordNeededField = () => {
if (!passwordFields.length) return null; if (
!passwordFields.length &&
!sshTunnelPasswordFields.length &&
!sshTunnelPrivateKeyFields.length &&
!sshTunnelPrivateKeyPasswordFields.length
)
return null;
return passwordFields.map(database => ( const files = [
...new Set([
...passwordFields,
...sshTunnelPasswordFields,
...sshTunnelPrivateKeyFields,
...sshTunnelPrivateKeyPasswordFields,
]),
];
return files.map(database => (
<> <>
<StyledAlertMargin> <StyledAlertMargin>
<Alert <Alert
@ -1197,19 +1286,77 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
)} )}
/> />
</StyledAlertMargin> </StyledAlertMargin>
<ValidatedInput {passwordFields?.indexOf(database) >= 0 && (
id="password_needed" <ValidatedInput
name="password_needed" id="password_needed"
required name="password_needed"
value={passwords[database]} required
onChange={(event: React.ChangeEvent<HTMLInputElement>) => value={passwords[database]}
setPasswords({ ...passwords, [database]: event.target.value }) onChange={(event: React.ChangeEvent<HTMLInputElement>) =>
} setPasswords({ ...passwords, [database]: event.target.value })
validationMethods={{ onBlur: () => {} }} }
errorMessage={validationErrors?.password_needed} validationMethods={{ onBlur: () => {} }}
label={t('%s PASSWORD', database.slice(10))} errorMessage={validationErrors?.password_needed}
css={formScrollableStyles} label={t('%s PASSWORD', database.slice(10))}
/> css={formScrollableStyles}
/>
)}
{sshTunnelPasswordFields?.indexOf(database) >= 0 && (
<ValidatedInput
id="ssh_tunnel_password_needed"
name="ssh_tunnel_password_needed"
required
value={sshTunnelPasswords[database]}
onChange={(event: React.ChangeEvent<HTMLInputElement>) =>
setSSHTunnelPasswords({
...sshTunnelPasswords,
[database]: event.target.value,
})
}
validationMethods={{ onBlur: () => {} }}
errorMessage={validationErrors?.ssh_tunnel_password_needed}
label={t('%s SSH TUNNEL PASSWORD', database.slice(10))}
css={formScrollableStyles}
/>
)}
{sshTunnelPrivateKeyFields?.indexOf(database) >= 0 && (
<ValidatedInput
id="ssh_tunnel_private_key_needed"
name="ssh_tunnel_private_key_needed"
required
value={sshTunnelPrivateKeys[database]}
onChange={(event: React.ChangeEvent<HTMLInputElement>) =>
setSSHTunnelPrivateKeys({
...sshTunnelPrivateKeys,
[database]: event.target.value,
})
}
validationMethods={{ onBlur: () => {} }}
errorMessage={validationErrors?.ssh_tunnel_private_key_needed}
label={t('%s SSH TUNNEL PRIVATE KEY', database.slice(10))}
css={formScrollableStyles}
/>
)}
{sshTunnelPrivateKeyPasswordFields?.indexOf(database) >= 0 && (
<ValidatedInput
id="ssh_tunnel_private_key_password_needed"
name="ssh_tunnel_private_key_password_needed"
required
value={sshTunnelPrivateKeyPasswords[database]}
onChange={(event: React.ChangeEvent<HTMLInputElement>) =>
setSSHTunnelPrivateKeyPasswords({
...sshTunnelPrivateKeyPasswords,
[database]: event.target.value,
})
}
validationMethods={{ onBlur: () => {} }}
errorMessage={
validationErrors?.ssh_tunnel_private_key_password_needed
}
label={t('%s SSH TUNNEL PRIVATE KEY PASSWORD', database.slice(10))}
css={formScrollableStyles}
/>
)}
</> </>
)); ));
}; };
@ -1468,7 +1615,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
); );
}; };
if (fileList.length > 0 && (alreadyExists.length || passwordFields.length)) { if (
fileList.length > 0 &&
(alreadyExists.length ||
passwordFields.length ||
sshTunnelPasswordFields.length ||
sshTunnelPrivateKeyFields.length ||
sshTunnelPrivateKeyPasswordFields.length)
) {
return ( return (
<Modal <Modal
css={(theme: SupersetTheme) => [ css={(theme: SupersetTheme) => [

View File

@ -163,6 +163,16 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
const [importingDataset, showImportModal] = useState<boolean>(false); const [importingDataset, showImportModal] = useState<boolean>(false);
const [passwordFields, setPasswordFields] = useState<string[]>([]); const [passwordFields, setPasswordFields] = useState<string[]>([]);
const [preparingExport, setPreparingExport] = useState<boolean>(false); const [preparingExport, setPreparingExport] = useState<boolean>(false);
const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
string[]
>([]);
const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
string[]
>([]);
const [
sshTunnelPrivateKeyPasswordFields,
setSSHTunnelPrivateKeyPasswordFields,
] = useState<string[]>([]);
const openDatasetImportModal = () => { const openDatasetImportModal = () => {
showImportModal(true); showImportModal(true);
@ -822,6 +832,14 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
onHide={closeDatasetImportModal} onHide={closeDatasetImportModal}
passwordFields={passwordFields} passwordFields={passwordFields}
setPasswordFields={setPasswordFields} setPasswordFields={setPasswordFields}
sshTunnelPasswordFields={sshTunnelPasswordFields}
setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
setSSHTunnelPrivateKeyPasswordFields={
setSSHTunnelPrivateKeyPasswordFields
}
/> />
{preparingExport && <Loading />} {preparingExport && <Loading />}
</> </>

View File

@ -115,6 +115,16 @@ function SavedQueryList({
const [importingSavedQuery, showImportModal] = useState<boolean>(false); const [importingSavedQuery, showImportModal] = useState<boolean>(false);
const [passwordFields, setPasswordFields] = useState<string[]>([]); const [passwordFields, setPasswordFields] = useState<string[]>([]);
const [preparingExport, setPreparingExport] = useState<boolean>(false); const [preparingExport, setPreparingExport] = useState<boolean>(false);
const [sshTunnelPasswordFields, setSSHTunnelPasswordFields] = useState<
string[]
>([]);
const [sshTunnelPrivateKeyFields, setSSHTunnelPrivateKeyFields] = useState<
string[]
>([]);
const [
sshTunnelPrivateKeyPasswordFields,
setSSHTunnelPrivateKeyPasswordFields,
] = useState<string[]>([]);
const openSavedQueryImportModal = () => { const openSavedQueryImportModal = () => {
showImportModal(true); showImportModal(true);
@ -577,6 +587,14 @@ function SavedQueryList({
onHide={closeSavedQueryImportModal} onHide={closeSavedQueryImportModal}
passwordFields={passwordFields} passwordFields={passwordFields}
setPasswordFields={setPasswordFields} setPasswordFields={setPasswordFields}
sshTunnelPasswordFields={sshTunnelPasswordFields}
setSSHTunnelPasswordFields={setSSHTunnelPasswordFields}
sshTunnelPrivateKeyFields={sshTunnelPrivateKeyFields}
setSSHTunnelPrivateKeyFields={setSSHTunnelPrivateKeyFields}
sshTunnelPrivateKeyPasswordFields={sshTunnelPrivateKeyPasswordFields}
setSSHTunnelPrivateKeyPasswordFields={
setSSHTunnelPrivateKeyPasswordFields
}
/> />
{preparingExport && <Loading />} {preparingExport && <Loading />}
</> </>

View File

@ -25,6 +25,9 @@ import {
getAlreadyExists, getAlreadyExists,
getPasswordsNeeded, getPasswordsNeeded,
hasTerminalValidation, hasTerminalValidation,
getSSHPasswordsNeeded,
getSSHPrivateKeysNeeded,
getSSHPrivateKeyPasswordsNeeded,
} from 'src/views/CRUD/utils'; } from 'src/views/CRUD/utils';
import { FetchDataConfig } from 'src/components/ListView'; import { FetchDataConfig } from 'src/components/ListView';
import { FilterValue } from 'src/components/ListView/types'; import { FilterValue } from 'src/components/ListView/types';
@ -386,6 +389,9 @@ interface ImportResourceState {
loading: boolean; loading: boolean;
passwordsNeeded: string[]; passwordsNeeded: string[];
alreadyExists: string[]; alreadyExists: string[];
sshPasswordNeeded: string[];
sshPrivateKeyNeeded: string[];
sshPrivateKeyPasswordNeeded: string[];
failed: boolean; failed: boolean;
} }
@ -398,6 +404,9 @@ export function useImportResource(
loading: false, loading: false,
passwordsNeeded: [], passwordsNeeded: [],
alreadyExists: [], alreadyExists: [],
sshPasswordNeeded: [],
sshPrivateKeyNeeded: [],
sshPrivateKeyPasswordNeeded: [],
failed: false, failed: false,
}); });
@ -409,6 +418,9 @@ export function useImportResource(
( (
bundle: File, bundle: File,
databasePasswords: Record<string, string> = {}, databasePasswords: Record<string, string> = {},
sshTunnelPasswords: Record<string, string> = {},
sshTunnelPrivateKey: Record<string, string> = {},
sshTunnelPrivateKeyPasswords: Record<string, string> = {},
overwrite = false, overwrite = false,
) => { ) => {
// Set loading state // Set loading state
@ -436,6 +448,33 @@ export function useImportResource(
if (overwrite) { if (overwrite) {
formData.append('overwrite', 'true'); formData.append('overwrite', 'true');
} }
/* The import bundle may contain ssh tunnel passwords; if required
* they should be provided by the user during import.
*/
if (sshTunnelPasswords) {
formData.append(
'ssh_tunnel_passwords',
JSON.stringify(sshTunnelPasswords),
);
}
/* The import bundle may contain ssh tunnel private_key; if required
* they should be provided by the user during import.
*/
if (sshTunnelPrivateKey) {
formData.append(
'ssh_tunnel_private_keys',
JSON.stringify(sshTunnelPrivateKey),
);
}
/* The import bundle may contain ssh tunnel private_key_password; if required
* they should be provided by the user during import.
*/
if (sshTunnelPrivateKeyPasswords) {
formData.append(
'ssh_tunnel_private_key_passwords',
JSON.stringify(sshTunnelPrivateKeyPasswords),
);
}
return SupersetClient.post({ return SupersetClient.post({
endpoint: `/api/v1/${resourceName}/import/`, endpoint: `/api/v1/${resourceName}/import/`,
@ -446,6 +485,9 @@ export function useImportResource(
updateState({ updateState({
passwordsNeeded: [], passwordsNeeded: [],
alreadyExists: [], alreadyExists: [],
sshPasswordNeeded: [],
sshPrivateKeyNeeded: [],
sshPrivateKeyPasswordNeeded: [],
failed: false, failed: false,
}); });
return true; return true;
@ -479,6 +521,11 @@ export function useImportResource(
} else { } else {
updateState({ updateState({
passwordsNeeded: getPasswordsNeeded(error.errors), passwordsNeeded: getPasswordsNeeded(error.errors),
sshPasswordNeeded: getSSHPasswordsNeeded(error.errors),
sshPrivateKeyNeeded: getSSHPrivateKeysNeeded(error.errors),
sshPrivateKeyPasswordNeeded: getSSHPrivateKeyPasswordsNeeded(
error.errors,
),
alreadyExists: getAlreadyExists(error.errors), alreadyExists: getAlreadyExists(error.errors),
}); });
} }

View File

@ -22,9 +22,15 @@ import {
getAlreadyExists, getAlreadyExists,
getFilterValues, getFilterValues,
getPasswordsNeeded, getPasswordsNeeded,
getSSHPasswordsNeeded,
getSSHPrivateKeysNeeded,
getSSHPrivateKeyPasswordsNeeded,
hasTerminalValidation, hasTerminalValidation,
isAlreadyExists, isAlreadyExists,
isNeedsPassword, isNeedsPassword,
isNeedsSSHPassword,
isNeedsSSHPrivateKey,
isNeedsSSHPrivateKeyPassword,
} from 'src/views/CRUD/utils'; } from 'src/views/CRUD/utils';
import { User } from 'src/types/bootstrapTypes'; import { User } from 'src/types/bootstrapTypes';
import { Filter, TableTab } from './types'; import { Filter, TableTab } from './types';
@ -112,6 +118,72 @@ const passwordNeededErrors = {
], ],
}; };
const sshTunnelPasswordNeededErrors = {
errors: [
{
message: 'Error importing database',
error_type: 'GENERIC_COMMAND_ERROR',
level: 'warning',
extra: {
'databases/imported_database.yaml': {
_schema: ['Must provide a password for the ssh tunnel'],
},
issue_codes: [
{
code: 1010,
message:
'Issue 1010 - Superset encountered an error while running a command.',
},
],
},
},
],
};
const sshTunnelPrivateKeyNeededErrors = {
errors: [
{
message: 'Error importing database',
error_type: 'GENERIC_COMMAND_ERROR',
level: 'warning',
extra: {
'databases/imported_database.yaml': {
_schema: ['Must provide a private key for the ssh tunnel'],
},
issue_codes: [
{
code: 1010,
message:
'Issue 1010 - Superset encountered an error while running a command.',
},
],
},
},
],
};
const sshTunnelPrivateKeyPasswordNeededErrors = {
errors: [
{
message: 'Error importing database',
error_type: 'GENERIC_COMMAND_ERROR',
level: 'warning',
extra: {
'databases/imported_database.yaml': {
_schema: ['Must provide a private key password for the ssh tunnel'],
},
issue_codes: [
{
code: 1010,
message:
'Issue 1010 - Superset encountered an error while running a command.',
},
],
},
},
],
};
test('identifies error payloads indicating that password is needed', () => { test('identifies error payloads indicating that password is needed', () => {
let needsPassword; let needsPassword;
@ -129,6 +201,63 @@ test('identifies error payloads indicating that password is needed', () => {
expect(needsPassword).toBe(false); expect(needsPassword).toBe(false);
}); });
test('identifies error payloads indicating that ssh_tunnel password is needed', () => {
let needsSSHTunnelPassword;
needsSSHTunnelPassword = isNeedsSSHPassword({
_schema: ['Must provide a password for the ssh tunnel'],
});
expect(needsSSHTunnelPassword).toBe(true);
needsSSHTunnelPassword = isNeedsSSHPassword(
'Database already exists and `overwrite=true` was not passed',
);
expect(needsSSHTunnelPassword).toBe(false);
needsSSHTunnelPassword = isNeedsSSHPassword({
type: ['Must be equal to Database.'],
});
expect(needsSSHTunnelPassword).toBe(false);
});
test('identifies error payloads indicating that ssh_tunnel private_key is needed', () => {
let needsSSHTunnelPrivateKey;
needsSSHTunnelPrivateKey = isNeedsSSHPrivateKey({
_schema: ['Must provide a private key for the ssh tunnel'],
});
expect(needsSSHTunnelPrivateKey).toBe(true);
needsSSHTunnelPrivateKey = isNeedsSSHPrivateKey(
'Database already exists and `overwrite=true` was not passed',
);
expect(needsSSHTunnelPrivateKey).toBe(false);
needsSSHTunnelPrivateKey = isNeedsSSHPrivateKey({
type: ['Must be equal to Database.'],
});
expect(needsSSHTunnelPrivateKey).toBe(false);
});
test('identifies error payloads indicating that ssh_tunnel private_key_password is needed', () => {
let needsSSHTunnelPrivateKeyPassword;
needsSSHTunnelPrivateKeyPassword = isNeedsSSHPrivateKeyPassword({
_schema: ['Must provide a private key password for the ssh tunnel'],
});
expect(needsSSHTunnelPrivateKeyPassword).toBe(true);
needsSSHTunnelPrivateKeyPassword = isNeedsSSHPrivateKeyPassword(
'Database already exists and `overwrite=true` was not passed',
);
expect(needsSSHTunnelPrivateKeyPassword).toBe(false);
needsSSHTunnelPrivateKeyPassword = isNeedsSSHPrivateKeyPassword({
type: ['Must be equal to Database.'],
});
expect(needsSSHTunnelPrivateKeyPassword).toBe(false);
});
test('identifies error payloads indicating that overwrite confirmation is needed', () => { test('identifies error payloads indicating that overwrite confirmation is needed', () => {
let alreadyExists; let alreadyExists;
@ -151,6 +280,29 @@ test('extracts DB configuration files that need passwords', () => {
expect(passwordsNeeded).toEqual(['databases/imported_database.yaml']); expect(passwordsNeeded).toEqual(['databases/imported_database.yaml']);
}); });
test('extracts DB configuration files that need ssh_tunnel passwords', () => {
const sshPasswordsNeeded = getSSHPasswordsNeeded(
sshTunnelPasswordNeededErrors.errors,
);
expect(sshPasswordsNeeded).toEqual(['databases/imported_database.yaml']);
});
test('extracts DB configuration files that need ssh_tunnel private_keys', () => {
const sshPrivateKeysNeeded = getSSHPrivateKeysNeeded(
sshTunnelPrivateKeyNeededErrors.errors,
);
expect(sshPrivateKeysNeeded).toEqual(['databases/imported_database.yaml']);
});
test('extracts DB configuration files that need ssh_tunnel private_key_passwords', () => {
const sshPrivateKeyPasswordsNeeded = getSSHPrivateKeyPasswordsNeeded(
sshTunnelPrivateKeyPasswordNeededErrors.errors,
);
expect(sshPrivateKeyPasswordsNeeded).toEqual([
'databases/imported_database.yaml',
]);
});
test('extracts files that need overwrite confirmation', () => { test('extracts files that need overwrite confirmation', () => {
const alreadyExists = getAlreadyExists(overwriteNeededErrors.errors); const alreadyExists = getAlreadyExists(overwriteNeededErrors.errors);
expect(alreadyExists).toEqual(['databases/imported_database.yaml']); expect(alreadyExists).toEqual(['databases/imported_database.yaml']);
@ -167,6 +319,17 @@ test('detects if the error message is terminal or if it requires uses interventi
isTerminal = hasTerminalValidation(passwordNeededErrors.errors); isTerminal = hasTerminalValidation(passwordNeededErrors.errors);
expect(isTerminal).toBe(false); expect(isTerminal).toBe(false);
isTerminal = hasTerminalValidation(sshTunnelPasswordNeededErrors.errors);
expect(isTerminal).toBe(false);
isTerminal = hasTerminalValidation(sshTunnelPrivateKeyNeededErrors.errors);
expect(isTerminal).toBe(false);
isTerminal = hasTerminalValidation(
sshTunnelPrivateKeyPasswordNeededErrors.errors,
);
expect(isTerminal).toBe(false);
}); });
test('error message is terminal when the "extra" field contains only the "issue_codes" key', () => { test('error message is terminal when the "extra" field contains only the "issue_codes" key', () => {

View File

@ -371,8 +371,34 @@ export /* eslint-disable no-underscore-dangle */
const isNeedsPassword = (payload: any) => const isNeedsPassword = (payload: any) =>
typeof payload === 'object' && typeof payload === 'object' &&
Array.isArray(payload._schema) && Array.isArray(payload._schema) &&
payload._schema.length === 1 && !!payload._schema?.find(
payload._schema[0] === 'Must provide a password for the database'; (e: string) => e === 'Must provide a password for the database',
);
export /* eslint-disable no-underscore-dangle */
const isNeedsSSHPassword = (payload: any) =>
typeof payload === 'object' &&
Array.isArray(payload._schema) &&
!!payload._schema?.find(
(e: string) => e === 'Must provide a password for the ssh tunnel',
);
export /* eslint-disable no-underscore-dangle */
const isNeedsSSHPrivateKey = (payload: any) =>
typeof payload === 'object' &&
Array.isArray(payload._schema) &&
!!payload._schema?.find(
(e: string) => e === 'Must provide a private key for the ssh tunnel',
);
export /* eslint-disable no-underscore-dangle */
const isNeedsSSHPrivateKeyPassword = (payload: any) =>
typeof payload === 'object' &&
Array.isArray(payload._schema) &&
!!payload._schema?.find(
(e: string) =>
e === 'Must provide a private key password for the ssh tunnel',
);
export const isAlreadyExists = (payload: any) => export const isAlreadyExists = (payload: any) =>
typeof payload === 'string' && typeof payload === 'string' &&
@ -387,6 +413,35 @@ export const getPasswordsNeeded = (errors: Record<string, any>[]) =>
) )
.flat(); .flat();
export const getSSHPasswordsNeeded = (errors: Record<string, any>[]) =>
errors
.map(error =>
Object.entries(error.extra)
.filter(([, payload]) => isNeedsSSHPassword(payload))
.map(([fileName]) => fileName),
)
.flat();
export const getSSHPrivateKeysNeeded = (errors: Record<string, any>[]) =>
errors
.map(error =>
Object.entries(error.extra)
.filter(([, payload]) => isNeedsSSHPrivateKey(payload))
.map(([fileName]) => fileName),
)
.flat();
export const getSSHPrivateKeyPasswordsNeeded = (
errors: Record<string, any>[],
) =>
errors
.map(error =>
Object.entries(error.extra)
.filter(([, payload]) => isNeedsSSHPrivateKeyPassword(payload))
.map(([fileName]) => fileName),
)
.flat();
export const getAlreadyExists = (errors: Record<string, any>[]) => export const getAlreadyExists = (errors: Record<string, any>[]) =>
errors errors
.map(error => .map(error =>
@ -405,7 +460,12 @@ export const hasTerminalValidation = (errors: Record<string, any>[]) =>
if (noIssuesCodes.length === 0) return true; if (noIssuesCodes.length === 0) return true;
return !noIssuesCodes.every( return !noIssuesCodes.every(
([, payload]) => isNeedsPassword(payload) || isAlreadyExists(payload), ([, payload]) =>
isNeedsPassword(payload) ||
isAlreadyExists(payload) ||
isNeedsSSHPassword(payload) ||
isNeedsSSHPrivateKey(payload) ||
isNeedsSSHPrivateKeyPassword(payload),
); );
}); });

View File

@ -882,6 +882,30 @@ class ChartRestApi(BaseSupersetModelRestApi):
overwrite: overwrite:
description: overwrite existing charts? description: overwrite existing charts?
type: boolean type: boolean
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses: responses:
200: 200:
description: Chart import result description: Chart import result
@ -918,9 +942,29 @@ class ChartRestApi(BaseSupersetModelRestApi):
else None else None
) )
overwrite = request.form.get("overwrite") == "true" overwrite = request.form.get("overwrite") == "true"
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportChartsCommand( command = ImportChartsCommand(
contents, passwords=passwords, overwrite=overwrite contents,
passwords=passwords,
overwrite=overwrite,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
) )
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -47,6 +47,15 @@ class ImportModelsCommand(BaseCommand):
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {} self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
self.ssh_tunnel_passwords: Dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
self.ssh_tunnel_private_keys: Dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self.overwrite: bool = kwargs.get("overwrite", False) self.overwrite: bool = kwargs.get("overwrite", False)
self._configs: Dict[str, Any] = {} self._configs: Dict[str, Any] = {}
@ -88,7 +97,13 @@ class ImportModelsCommand(BaseCommand):
# load the configs and make sure we have confirmation to overwrite existing models # load the configs and make sure we have confirmation to overwrite existing models
self._configs = load_configs( self._configs = load_configs(
self.contents, self.schemas, self.passwords, exceptions self.contents,
self.schemas,
self.passwords,
exceptions,
self.ssh_tunnel_passwords,
self.ssh_tunnel_private_keys,
self.ssh_tunnel_priv_key_passwords,
) )
self._prevent_overwrite_existing_model(exceptions) self._prevent_overwrite_existing_model(exceptions)

View File

@ -68,6 +68,15 @@ class ImportAssetsCommand(BaseCommand):
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {} self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
self.ssh_tunnel_passwords: Dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
self.ssh_tunnel_private_keys: Dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self._configs: Dict[str, Any] = {} self._configs: Dict[str, Any] = {}
@staticmethod @staticmethod
@ -153,7 +162,13 @@ class ImportAssetsCommand(BaseCommand):
validate_metadata_type(metadata, "assets", exceptions) validate_metadata_type(metadata, "assets", exceptions)
self._configs = load_configs( self._configs = load_configs(
self.contents, self.schemas, self.passwords, exceptions self.contents,
self.schemas,
self.passwords,
exceptions,
self.ssh_tunnel_passwords,
self.ssh_tunnel_private_keys,
self.ssh_tunnel_priv_key_passwords,
) )
if exceptions: if exceptions:

View File

@ -24,6 +24,7 @@ from marshmallow.exceptions import ValidationError
from superset import db from superset import db
from superset.commands.importers.exceptions import IncorrectVersionError from superset.commands.importers.exceptions import IncorrectVersionError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database from superset.models.core import Database
METADATA_FILE_NAME = "metadata.yaml" METADATA_FILE_NAME = "metadata.yaml"
@ -93,11 +94,15 @@ def validate_metadata_type(
exceptions.append(exc) exceptions.append(exc)
# pylint: disable=too-many-locals,too-many-arguments
def load_configs( def load_configs(
contents: Dict[str, str], contents: Dict[str, str],
schemas: Dict[str, Schema], schemas: Dict[str, Schema],
passwords: Dict[str, str], passwords: Dict[str, str],
exceptions: List[ValidationError], exceptions: List[ValidationError],
ssh_tunnel_passwords: Dict[str, str],
ssh_tunnel_private_keys: Dict[str, str],
ssh_tunnel_priv_key_passwords: Dict[str, str],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
configs: Dict[str, Any] = {} configs: Dict[str, Any] = {}
@ -106,6 +111,25 @@ def load_configs(
str(uuid): password str(uuid): password
for uuid, password in db.session.query(Database.uuid, Database.password).all() for uuid, password in db.session.query(Database.uuid, Database.password).all()
} }
# load existing ssh_tunnels so we can apply the password validation
db_ssh_tunnel_passwords: Dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
}
# load existing ssh_tunnels so we can apply the private_key validation
db_ssh_tunnel_private_keys: Dict[str, str] = {
str(uuid): private_key
for uuid, private_key in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key
).all()
}
# load existing ssh_tunnels so we can apply the private_key_password validation
db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
str(uuid): private_key_password
for uuid, private_key_password in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key_password
).all()
}
for file_name, content in contents.items(): for file_name, content in contents.items():
# skip directories # skip directories
if not content: if not content:
@ -123,6 +147,42 @@ def load_configs(
elif prefix == "databases" and config["uuid"] in db_passwords: elif prefix == "databases" and config["uuid"] in db_passwords:
config["password"] = db_passwords[config["uuid"]] config["password"] = db_passwords[config["uuid"]]
# populate ssh_tunnel_passwords from the request or from existing DBs
if file_name in ssh_tunnel_passwords:
config["ssh_tunnel"]["password"] = ssh_tunnel_passwords[file_name]
elif (
prefix == "databases" and config["uuid"] in db_ssh_tunnel_passwords
):
config["ssh_tunnel"]["password"] = db_ssh_tunnel_passwords[
config["uuid"]
]
# populate ssh_tunnel_private_keys from the request or from existing DBs
if file_name in ssh_tunnel_private_keys:
config["ssh_tunnel"]["private_key"] = ssh_tunnel_private_keys[
file_name
]
elif (
prefix == "databases"
and config["uuid"] in db_ssh_tunnel_private_keys
):
config["ssh_tunnel"]["private_key"] = db_ssh_tunnel_private_keys[
config["uuid"]
]
# populate ssh_tunnel_passwords from the request or from existing DBs
if file_name in ssh_tunnel_priv_key_passwords:
config["ssh_tunnel"][
"private_key_password"
] = ssh_tunnel_priv_key_passwords[file_name]
elif (
prefix == "databases"
and config["uuid"] in db_ssh_tunnel_priv_key_passws
):
config["ssh_tunnel"][
"private_key_password"
] = db_ssh_tunnel_priv_key_passws[config["uuid"]]
schema.load(config) schema.load(config)
configs[file_name] = config configs[file_name] = config
except ValidationError as exc: except ValidationError as exc:

View File

@ -1035,6 +1035,30 @@ class DashboardRestApi(BaseSupersetModelRestApi):
overwrite: overwrite:
description: overwrite existing dashboards? description: overwrite existing dashboards?
type: boolean type: boolean
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses: responses:
200: 200:
description: Dashboard import result description: Dashboard import result
@ -1074,8 +1098,29 @@ class DashboardRestApi(BaseSupersetModelRestApi):
) )
overwrite = request.form.get("overwrite") == "true" overwrite = request.form.get("overwrite") == "true"
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportDashboardsCommand( command = ImportDashboardsCommand(
contents, passwords=passwords, overwrite=overwrite contents,
passwords=passwords,
overwrite=overwrite,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
) )
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -1095,6 +1095,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
overwrite: overwrite:
description: overwrite existing databases? description: overwrite existing databases?
type: boolean type: boolean
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses: responses:
200: 200:
description: Database import result description: Database import result
@ -1131,9 +1155,29 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
else None else None
) )
overwrite = request.form.get("overwrite") == "true" overwrite = request.form.get("overwrite") == "true"
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportDatabasesCommand( command = ImportDatabasesCommand(
contents, passwords=passwords, overwrite=overwrite contents,
passwords=passwords,
overwrite=overwrite,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
) )
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -28,6 +28,7 @@ from superset.commands.export.models import ExportModelsCommand
from superset.models.core import Database from superset.models.core import Database
from superset.utils.dict_import_export import EXPORT_VERSION from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename from superset.utils.file import get_filename
from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -87,6 +88,15 @@ class ExportDatabasesCommand(ExportModelsCommand):
"schemas_allowed_for_file_upload" "schemas_allowed_for_file_upload"
) )
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.id):
ssh_tunnel_payload = ssh_tunnel.export_to_dict(
recursive=False,
include_parent_ref=False,
include_defaults=True,
export_uuids=False,
)
payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
payload["version"] = EXPORT_VERSION payload["version"] = EXPORT_VERSION
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)

View File

@ -20,6 +20,7 @@ from typing import Any, Dict
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database from superset.models.core import Database
@ -42,8 +43,15 @@ def import_database(
# TODO (betodealmeida): move this logic to import_from_dict # TODO (betodealmeida): move this logic to import_from_dict
config["extra"] = json.dumps(config["extra"]) config["extra"] = json.dumps(config["extra"])
# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)
database = Database.import_from_dict(session, config, recursive=False) database = Database.import_from_dict(session, config, recursive=False)
if database.id is None: if database.id is None:
session.flush() session.flush()
if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
return database return database

View File

@ -19,7 +19,7 @@
import inspect import inspect
import json import json
from typing import Any, Dict from typing import Any, Dict, List
from flask import current_app from flask import current_app
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
@ -28,9 +28,14 @@ from marshmallow.validate import Length, ValidationError
from marshmallow_enum import EnumField from marshmallow_enum import EnumField
from sqlalchemy import MetaData from sqlalchemy import MetaData
from superset import db from superset import db, is_feature_enabled
from superset.constants import PASSWORD_MASK from superset.constants import PASSWORD_MASK
from superset.databases.commands.exceptions import DatabaseInvalidError from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelingNotEnabledError,
SSHTunnelInvalidCredentials,
SSHTunnelMissingCredentials,
)
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException from superset.exceptions import CertificateException, SupersetSecurityException
@ -706,6 +711,7 @@ class ImportV1DatabaseSchema(Schema):
version = fields.String(required=True) version = fields.String(required=True)
is_managed_externally = fields.Boolean(allow_none=True, default=False) is_managed_externally = fields.Boolean(allow_none=True, default=False)
external_url = fields.String(allow_none=True) external_url = fields.String(allow_none=True)
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
@validates_schema @validates_schema
def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None: def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
@ -720,6 +726,68 @@ class ImportV1DatabaseSchema(Schema):
if password == PASSWORD_MASK and data.get("password") is None: if password == PASSWORD_MASK and data.get("password") is None:
raise ValidationError("Must provide a password for the database") raise ValidationError("Must provide a password for the database")
@validates_schema
def validate_ssh_tunnel_credentials(
self, data: Dict[str, Any], **kwargs: Any
) -> None:
"""If ssh_tunnel has a masked credentials, credentials are required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
if existing:
return
# Our DB has a ssh_tunnel in it
if ssh_tunnel := data.get("ssh_tunnel"):
# Login methods are (only one from these options):
# 1. password
# 2. private_key + private_key_password
# Based on the data passed we determine what info is required.
# You cannot mix the credentials from both methods.
if not is_feature_enabled("SSH_TUNNELING"):
# You are trying to import a Database with SSH Tunnel
# But the Feature Flag is not enabled.
raise SSHTunnelingNotEnabledError()
password = ssh_tunnel.get("password")
private_key = ssh_tunnel.get("private_key")
private_key_password = ssh_tunnel.get("private_key_password")
if password is not None:
# Login method #1 (Password)
if private_key is not None or private_key_password is not None:
# You cannot have a mix of login methods
raise SSHTunnelInvalidCredentials()
if password == PASSWORD_MASK:
raise ValidationError("Must provide a password for the ssh tunnel")
if password is None:
# If the SSH Tunnel we're importing has no password then it must
# have a private_key + private_key_password combination
if private_key is None and private_key_password is None:
# We have found nothing related to other credentials
raise SSHTunnelMissingCredentials()
# We need to ask for the missing properties of our method # 2
# Some times the property is just missing
# or there're times where it's masked.
# If both are masked, we need to return a list of errors
# so the UI ask for both fields at the same time if needed
exception_messages: List[str] = []
if private_key is None or private_key == PASSWORD_MASK:
# If we get here we need to ask for the private key
exception_messages.append(
"Must provide a private key for the ssh tunnel"
)
if (
private_key_password is None
or private_key_password == PASSWORD_MASK
):
# If we get here we need to ask for the private key password
exception_messages.append(
"Must provide a private key password for the ssh tunnel"
)
if exception_messages:
# We can ask for just one field or both if masked, if both
# are empty, SSHTunnelMissingCredentials was already raised
raise ValidationError(exception_messages)
return
class EncryptedField: # pylint: disable=too-few-public-methods class EncryptedField: # pylint: disable=too-few-public-methods
""" """

View File

@ -57,3 +57,11 @@ class SSHTunnelRequiredFieldValidationError(ValidationError):
[_("Field is required")], [_("Field is required")],
field_name=field_name, field_name=field_name,
) )
class SSHTunnelMissingCredentials(CommandInvalidError):
message = _("Must provide credentials for the SSH Tunnel")
class SSHTunnelInvalidCredentials(CommandInvalidError):
message = _("Cannot have multiple credentials for the SSH Tunnel")

View File

@ -68,6 +68,19 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
) )
export_fields = [
"server_address",
"server_port",
"username",
"password",
"private_key",
"private_key_password",
]
extra_import_fields = [
"database_id",
]
@property @property
def data(self) -> Dict[str, Any]: def data(self) -> Dict[str, Any]:
output = { output = {

View File

@ -830,6 +830,30 @@ class DatasetRestApi(BaseSupersetModelRestApi):
sync_metrics: sync_metrics:
description: sync metrics? description: sync metrics?
type: boolean type: boolean
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses: responses:
200: 200:
description: Dataset import result description: Dataset import result
@ -870,6 +894,21 @@ class DatasetRestApi(BaseSupersetModelRestApi):
overwrite = request.form.get("overwrite") == "true" overwrite = request.form.get("overwrite") == "true"
sync_columns = request.form.get("sync_columns") == "true" sync_columns = request.form.get("sync_columns") == "true"
sync_metrics = request.form.get("sync_metrics") == "true" sync_metrics = request.form.get("sync_metrics") == "true"
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportDatasetsCommand( command = ImportDatasetsCommand(
contents, contents,
@ -877,6 +916,9 @@ class DatasetRestApi(BaseSupersetModelRestApi):
overwrite=overwrite, overwrite=overwrite,
sync_columns=sync_columns, sync_columns=sync_columns,
sync_metrics=sync_metrics, sync_metrics=sync_metrics,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
) )
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -24,10 +24,12 @@ import yaml
from superset.commands.export.models import ExportModelsCommand from superset.commands.export.models import ExportModelsCommand
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.databases.dao import DatabaseDAO
from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.utils.dict_import_export import EXPORT_VERSION from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename from superset.utils.file import get_filename
from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -97,6 +99,15 @@ class ExportDatasetsCommand(ExportModelsCommand):
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
logger.info("Unable to decode `extra` field: %s", payload["extra"]) logger.info("Unable to decode `extra` field: %s", payload["extra"])
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.database.id):
ssh_tunnel_payload = ssh_tunnel.export_to_dict(
recursive=False,
include_parent_ref=False,
include_defaults=True,
export_uuids=False,
)
payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
payload["version"] = EXPORT_VERSION payload["version"] = EXPORT_VERSION
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)

View File

@ -122,6 +122,30 @@ class ImportExportRestApi(BaseSupersetApi):
in the following format: in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`. `{"databases/MyDatabase.yaml": "my_password"}`.
type: string type: string
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses: responses:
200: 200:
description: Assets import result description: Assets import result
@ -158,7 +182,28 @@ class ImportExportRestApi(BaseSupersetApi):
if "passwords" in request.form if "passwords" in request.form
else None else None
) )
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportAssetsCommand(contents, passwords=passwords) command = ImportAssetsCommand(
contents,
passwords=passwords,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -324,6 +324,30 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
overwrite: overwrite:
description: overwrite existing saved queries? description: overwrite existing saved queries?
type: boolean type: boolean
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses: responses:
200: 200:
description: Saved Query import result description: Saved Query import result
@ -360,9 +384,29 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
else None else None
) )
overwrite = request.form.get("overwrite") == "true" overwrite = request.form.get("overwrite") == "true"
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportSavedQueriesCommand( command = ImportSavedQueriesCommand(
contents, passwords=passwords, overwrite=overwrite contents,
passwords=passwords,
overwrite=overwrite,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
) )
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -66,6 +66,11 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config, dataset_config,
database_metadata_config, database_metadata_config,
dataset_metadata_config, dataset_metadata_config,
database_with_ssh_tunnel_config_password,
database_with_ssh_tunnel_config_private_key,
database_with_ssh_tunnel_config_mix_credentials,
database_with_ssh_tunnel_config_no_credentials,
database_with_ssh_tunnel_config_private_pass_only,
) )
from tests.integration_tests.fixtures.unicode_dashboard import ( from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_position, load_unicode_dashboard_with_position,
@ -2361,6 +2366,449 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(database) db.session.delete(database)
db.session.commit() db.session.commit()
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_password(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with masked password
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_password.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/imported_database.yaml": {
"_schema": ["Must provide a password for the ssh tunnel"]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_password_provided(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with masked password provided
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_password.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
"ssh_tunnel_passwords": json.dumps(
{"databases/imported_database.yaml": "TEST"}
),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.password, "TEST")
db.session.delete(database)
db.session.commit()
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_private_key_and_password(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with masked private_key
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/imported_database.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel",
]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_private_key_and_password_provided(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with masked password provided
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
"ssh_tunnel_private_keys": json.dumps(
{"databases/imported_database.yaml": "TestPrivateKey"}
),
"ssh_tunnel_private_key_passwords": json.dumps(
{"databases/imported_database.yaml": "TEST"}
),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
db.session.delete(database)
db.session.commit()
def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self):
"""
Database API: Test import database with ssh_tunnel and feature flag disabled
"""
self.login(username="admin")
uri = "api/v1/database/import/"
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert response == {
"errors": [
{
"message": "SSH Tunneling is not enabled",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_feature_no_credentials(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Must provide credentials for the SSH Tunnel",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_feature_mix_credentials(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Cannot have multiple credentials for the SSH Tunnel",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd(
self, mock_schema_is_feature_enabled
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(username="admin")
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = (
database_with_ssh_tunnel_config_private_pass_only.copy()
)
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("database_export/metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open(
"database_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(masked_database_config).encode())
with bundle.open(
"database_export/datasets/imported_dataset.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/imported_database.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel",
]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch( @mock.patch(
"superset.db_engine_specs.base.BaseEngineSpec.get_function_names", "superset.db_engine_specs.base.BaseEngineSpec.get_function_names",
) )

View File

@ -41,6 +41,7 @@ from superset.databases.commands.tables import TablesDatabaseCommand
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.validate import ValidateDatabaseParametersCommand from superset.databases.commands.validate import ValidateDatabaseParametersCommand
from superset.databases.schemas import DatabaseTestConnectionSchema from superset.databases.schemas import DatabaseTestConnectionSchema
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import ( from superset.exceptions import (
SupersetErrorsException, SupersetErrorsException,
@ -63,6 +64,11 @@ from tests.integration_tests.fixtures.energy_dashboard import (
from tests.integration_tests.fixtures.importexport import ( from tests.integration_tests.fixtures.importexport import (
database_config, database_config,
database_metadata_config, database_metadata_config,
database_with_ssh_tunnel_config_mix_credentials,
database_with_ssh_tunnel_config_no_credentials,
database_with_ssh_tunnel_config_password,
database_with_ssh_tunnel_config_private_key,
database_with_ssh_tunnel_config_private_pass_only,
dataset_config, dataset_config,
dataset_metadata_config, dataset_metadata_config,
) )
@ -623,6 +629,191 @@ class TestImportDatabasesCommand(SupersetTestCase):
} }
} }
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_masked_ssh_tunnel_password(
self, mock_schema_is_feature_enabled
):
"""Test that database imports with masked ssh_tunnel passwords are rejected"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_password.copy()
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == "Error importing database"
assert excinfo.value.normalized_messages() == {
"databases/imported_database.yaml": {
"_schema": ["Must provide a password for the ssh tunnel"]
}
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_masked_ssh_tunnel_private_key_and_password(
self, mock_schema_is_feature_enabled
):
"""Test that database imports with masked ssh_tunnel private_key and private_key_password are rejected"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == "Error importing database"
assert excinfo.value.normalized_messages() == {
"databases/imported_database.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel",
]
}
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_with_ssh_tunnel_password(
self, mock_schema_is_feature_enabled
):
"""Test that a database with ssh_tunnel password can be imported"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_password.copy()
masked_database_config["ssh_tunnel"]["password"] = "TEST"
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
command.run()
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.allow_file_upload
assert database.allow_ctas
assert database.allow_cvas
assert database.allow_dml
assert not database.allow_run_async
assert database.cache_timeout is None
assert database.database_name == "imported_database"
assert database.expose_in_sqllab
assert database.extra == "{}"
assert database.sqlalchemy_uri == "sqlite:///test.db"
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.password, "TEST")
db.session.delete(database)
db.session.commit()
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_with_ssh_tunnel_private_key_and_password(
self, mock_schema_is_feature_enabled
):
"""Test that a database with ssh_tunnel private_key and private_key_password can be imported"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
masked_database_config["ssh_tunnel"]["private_key"] = "TestPrivateKey"
masked_database_config["ssh_tunnel"]["private_key_password"] = "TEST"
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
command.run()
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.allow_file_upload
assert database.allow_ctas
assert database.allow_cvas
assert database.allow_dml
assert not database.allow_run_async
assert database.cache_timeout is None
assert database.database_name == "imported_database"
assert database.expose_in_sqllab
assert database.extra == "{}"
assert database.sqlalchemy_uri == "sqlite:///test.db"
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
db.session.delete(database)
db.session.commit()
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_masked_ssh_tunnel_no_credentials(
self, mock_schema_is_feature_enabled
):
"""Test that databases with ssh_tunnels that have no credentials are rejected"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == "Must provide credentials for the SSH Tunnel"
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_masked_ssh_tunnel_multiple_credentials(
self, mock_schema_is_feature_enabled
):
"""Test that databases with ssh_tunnels that have multiple credentials are rejected"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
command.run()
assert (
str(excinfo.value) == "Cannot have multiple credentials for the SSH Tunnel"
)
@mock.patch("superset.databases.schemas.is_feature_enabled")
def test_import_v1_database_masked_ssh_tunnel_only_priv_key_psswd(
self, mock_schema_is_feature_enabled
):
"""Test that databases with ssh_tunnels that have multiple credentials are rejected"""
mock_schema_is_feature_enabled.return_value = True
masked_database_config = (
database_with_ssh_tunnel_config_private_pass_only.copy()
)
contents = {
"metadata.yaml": yaml.safe_dump(database_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(masked_database_config),
}
command = ImportDatabasesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == "Error importing database"
assert excinfo.value.normalized_messages() == {
"databases/imported_database.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel",
]
}
}
@patch("superset.databases.commands.importers.v1.import_dataset") @patch("superset.databases.commands.importers.v1.import_dataset")
def test_import_v1_rollback(self, mock_import_dataset): def test_import_v1_rollback(self, mock_import_dataset):
"""Test than on an exception everything is rolled back""" """Test than on an exception everything is rolled back"""

View File

@ -361,6 +361,113 @@ database_config: Dict[str, Any] = {
"version": "1.0.0", "version": "1.0.0",
} }
database_with_ssh_tunnel_config_private_key: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
"server_port": 22,
"username": "Test",
"private_key": "XXXXXXXXXX",
"private_key_password": "XXXXXXXXXX",
},
"version": "1.0.0",
}
database_with_ssh_tunnel_config_password: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
"server_port": 22,
"username": "Test",
"password": "XXXXXXXXXX",
},
"version": "1.0.0",
}
database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
"server_port": 22,
"username": "Test",
},
"version": "1.0.0",
}
database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
"server_port": 22,
"username": "Test",
"password": "XXXXXXXXXX",
"private_key": "XXXXXXXXXX",
},
"version": "1.0.0",
}
database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
"server_port": 22,
"username": "Test",
"private_key_password": "XXXXXXXXXX",
},
"version": "1.0.0",
}
dataset_config: Dict[str, Any] = { dataset_config: Dict[str, Any] = {
"table_name": "imported_dataset", "table_name": "imported_dataset",
"main_dttm_col": None, "main_dttm_col": None,

View File

@ -99,7 +99,13 @@ def test_import_assets(
assert response.json == {"message": "OK"} assert response.json == {"message": "OK"}
passwords = {"assets_export/databases/imported_database.yaml": "SECRET"} passwords = {"assets_export/databases/imported_database.yaml": "SECRET"}
ImportAssetsCommand.assert_called_with(mocked_contents, passwords=passwords) ImportAssetsCommand.assert_called_with(
mocked_contents,
passwords=passwords,
ssh_tunnel_passwords=None,
ssh_tunnel_private_keys=None,
ssh_tunnel_priv_key_passwords=None,
)
def test_import_assets_not_zip( def test_import_assets_not_zip(