fix: SSH Tunnel configuration settings (#27186)

This commit is contained in:
Geido 2024-03-11 16:56:54 +01:00 committed by GitHub
parent fde93dcf08
commit 89e89de341
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 871 additions and 271 deletions

View File

@ -44,15 +44,15 @@ interface MenuObjectChildProps {
disable?: boolean;
}
export interface SwitchProps {
isEditMode: boolean;
dbFetched: any;
disableSSHTunnelingForEngine?: boolean;
useSSHTunneling: boolean;
setUseSSHTunneling: React.Dispatch<React.SetStateAction<boolean>>;
setDB: React.Dispatch<any>;
isSSHTunneling: boolean;
}
// loose typing to avoid any circular dependencies
// refer to SSHTunnelSwitch component for strict typing
type SwitchProps = {
db: object;
changeMethods: {
onParametersChange: (event: any) => void;
};
clearValidationErrors: () => void;
};
type ConfigDetailsProps = {
embeddedId: string;

View File

@ -541,8 +541,8 @@ test('defaults to day when CRON is not selected', async () => {
useRedux: true,
});
userEvent.click(screen.getByTestId('schedule-panel'));
const days = screen.getAllByTitle(/day/i, { exact: true });
expect(days.length).toBe(2);
const day = screen.getByText('day');
expect(day).toBeInTheDocument();
});
// Notification Method Section

View File

@ -17,12 +17,11 @@
* under the License.
*/
import React from 'react';
import { isEmpty } from 'lodash';
import { SupersetTheme, t } from '@superset-ui/core';
import { AntdSwitch } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import { FieldPropTypes } from '.';
import { FieldPropTypes } from '../../types';
import { toggleStyle, infoTooltip } from '../styles';
export const hostField = ({
@ -252,35 +251,3 @@ export const forceSSLField = ({
/>
</div>
);
export const SSHTunnelSwitch = ({
isEditMode,
changeMethods,
clearValidationErrors,
db,
}: FieldPropTypes) => (
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
<AntdSwitch
disabled={isEditMode && !isEmpty(db?.ssh_tunnel)}
checked={db?.parameters?.ssh}
onChange={changed => {
changeMethods.onParametersChange({
target: {
type: 'toggle',
name: 'ssh',
checked: true,
value: changed,
},
});
clearValidationErrors();
}}
data-test="ssh-tunnel-switch"
/>
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
<InfoTooltip
tooltip={t('SSH Tunnel configuration parameters')}
placement="right"
viewBox="0 -5 24 24"
/>
</div>
);

View File

@ -22,7 +22,7 @@ import { AntdButton, AntdSelect } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip';
import FormLabel from 'src/components/Form/FormLabel';
import Icons from 'src/components/Icons';
import { FieldPropTypes } from '.';
import { FieldPropTypes } from '../../types';
import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles';
enum CredentialInfoOptions {

View File

@ -21,9 +21,8 @@ import { css, SupersetTheme, t } from '@superset-ui/core';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import FormLabel from 'src/components/Form/FormLabel';
import Icons from 'src/components/Icons';
import { FieldPropTypes } from '.';
import { StyledFooterButton, StyledCatalogTable } from '../styles';
import { CatalogObject } from '../../types';
import { CatalogObject, FieldPropTypes } from '../../types';
export const TableCatalog = ({
required,

View File

@ -19,7 +19,7 @@
import React from 'react';
import { t } from '@superset-ui/core';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import { FieldPropTypes } from '.';
import { FieldPropTypes } from '../../types';
const FIELD_TEXT_MAP = {
account: {

View File

@ -17,7 +17,11 @@
* under the License.
*/
import React, { FormEvent } from 'react';
import { SupersetTheme, JsonObject } from '@superset-ui/core';
import {
SupersetTheme,
JsonObject,
getExtensionsRegistry,
} from '@superset-ui/core';
import { InputProps } from 'antd/lib/input';
import { Form } from 'src/components/Form';
import {
@ -31,13 +35,13 @@ import {
portField,
queryField,
usernameField,
SSHTunnelSwitch,
} from './CommonParameters';
import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField';
import { TableCatalog } from './TableCatalog';
import { formScrollableStyles, validatedFormStyles } from '../styles';
import { DatabaseForm, DatabaseObject } from '../../types';
import SSHTunnelSwitch from '../SSHTunnelSwitch';
export const FormFieldOrder = [
'host',
@ -59,34 +63,10 @@ export const FormFieldOrder = [
'ssh',
];
export interface FieldPropTypes {
required: boolean;
hasTooltip?: boolean;
tooltipText?: (value: any) => string;
placeholder?: string;
onParametersChange: (value: any) => string;
onParametersUploadFileChange: (value: any) => string;
changeMethods: { onParametersChange: (value: any) => string } & {
onChange: (value: any) => string;
} & {
onQueryChange: (value: any) => string;
} & { onParametersUploadFileChange: (value: any) => string } & {
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
} & {
onExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: (value: any) => string;
};
validationErrors: JsonObject | null;
getValidation: () => void;
clearValidationErrors: () => void;
db?: DatabaseObject;
field: string;
isEditMode?: boolean;
sslForced?: boolean;
defaultDBName?: string;
editNewDb?: boolean;
}
const extensionsRegistry = getExtensionsRegistry();
const SSHTunnelSwitchComponent =
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
const FORM_FIELD_MAP = {
host: hostField,
@ -105,7 +85,7 @@ const FORM_FIELD_MAP = {
warehouse: validatedInputField,
role: validatedInputField,
account: validatedInputField,
ssh: SSHTunnelSwitch,
ssh: SSHTunnelSwitchComponent,
};
interface DatabaseConnectionFormProps {
@ -138,7 +118,7 @@ interface DatabaseConnectionFormProps {
}
const DatabaseConnectionForm = ({
dbModel: { parameters },
dbModel,
db,
editNewDb,
getPlaceholder,
@ -154,47 +134,51 @@ const DatabaseConnectionForm = ({
sslForced,
validationErrors,
clearValidationErrors,
}: DatabaseConnectionFormProps) => (
<Form>
<div
// @ts-ignore
css={(theme: SupersetTheme) => [
formScrollableStyles,
validatedFormStyles(theme),
]}
>
{parameters &&
FormFieldOrder.filter(
(key: string) =>
Object.keys(parameters.properties).includes(key) ||
key === 'database_name',
).map(field =>
FORM_FIELD_MAP[field]({
required: parameters.required?.includes(field),
changeMethods: {
onParametersChange,
onChange,
onQueryChange,
onParametersUploadFileChange,
onAddTableCatalog,
onRemoveTableCatalog,
onExtraInputChange,
},
validationErrors,
getValidation,
clearValidationErrors,
db,
key: field,
field,
isEditMode,
sslForced,
editNewDb,
placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
}),
)}
</div>
</Form>
);
}: DatabaseConnectionFormProps) => {
const parameters = dbModel?.parameters;
return (
<Form>
<div
// @ts-ignore
css={(theme: SupersetTheme) => [
formScrollableStyles,
validatedFormStyles(theme),
]}
>
{parameters &&
FormFieldOrder.filter(
(key: string) =>
Object.keys(parameters.properties).includes(key) ||
key === 'database_name',
).map(field =>
FORM_FIELD_MAP[field]({
required: parameters.required?.includes(field),
changeMethods: {
onParametersChange,
onChange,
onQueryChange,
onParametersUploadFileChange,
onAddTableCatalog,
onRemoveTableCatalog,
onExtraInputChange,
},
validationErrors,
getValidation,
clearValidationErrors,
db,
key: field,
field,
isEditMode,
sslForced,
editNewDb,
placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
}),
)}
</div>
</Form>
);
};
export const FormFieldMap = FORM_FIELD_MAP;
export default DatabaseConnectionForm;

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import React, { EventHandler, ChangeEvent, useState } from 'react';
import React, { useState } from 'react';
import { t, styled } from '@superset-ui/core';
import { AntdForm, Col, Row } from 'src/components';
import { Form, FormLabel } from 'src/components/Form';
@ -24,7 +24,7 @@ import { Radio } from 'src/components/Radio';
import { Input, TextArea } from 'src/components/Input';
import { Input as AntdInput, Tooltip } from 'antd';
import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
import { DatabaseObject } from '../types';
import { DatabaseObject, FieldPropTypes } from '../types';
import { AuthType } from '.';
const StyledDiv = styled.div`
@ -54,9 +54,7 @@ const SSHTunnelForm = ({
setSSHTunnelLoginMethod,
}: {
db: DatabaseObject | null;
onSSHTunnelParametersChange: EventHandler<
ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
>;
onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange'];
setSSHTunnelLoginMethod: (method: AuthType) => void;
}) => {
const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password);
@ -86,9 +84,9 @@ const SSHTunnelForm = ({
</FormLabel>
<Input
name="server_port"
type="text"
placeholder={t('22')}
value={db?.ssh_tunnel?.server_port || ''}
type="number"
value={db?.ssh_tunnel?.server_port}
onChange={onSSHTunnelParametersChange}
data-test="ssh-tunnel-server_port-input"
/>

View File

@ -0,0 +1,162 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import React from 'react';
import { render, screen } from 'spec/helpers/testing-library';
import userEvent from '@testing-library/user-event';
import SSHTunnelSwitch from './SSHTunnelSwitch';
import { DatabaseForm, DatabaseObject } from '../types';
jest.mock('@superset-ui/core', () => ({
...jest.requireActual('@superset-ui/core'),
isFeatureEnabled: jest.fn().mockReturnValue(true),
}));
jest.mock('src/components', () => ({
AntdSwitch: ({
checked,
onChange,
}: {
checked: boolean;
onChange: (checked: boolean) => void;
}) => (
<button
onClick={() => onChange(!checked)}
aria-checked={checked}
role="switch"
type="button"
>
{checked ? 'ON' : 'OFF'}
</button>
),
}));
const mockChangeMethods = {
onParametersChange: jest.fn(),
};
const mockDbModel = {
engine: 'mysql',
engine_information: {
disable_ssh_tunneling: false,
},
} as DatabaseForm;
const defaultDb = {
parameters: { ssh: false },
ssh_tunnel: {},
engine: 'mysql',
} as DatabaseObject;
afterEach(() => {
jest.clearAllMocks();
});
test('Renders SSH Tunnel switch enabled by default and toggles its state', () => {
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={defaultDb}
dbModel={mockDbModel}
/>,
);
const switchButton = screen.getByRole('switch');
expect(switchButton).toHaveTextContent('OFF');
userEvent.click(switchButton);
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
});
expect(switchButton).toHaveTextContent('ON');
});
test('Does not render if SSH Tunnel is disabled', () => {
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={defaultDb}
dbModel={{
...mockDbModel,
engine_information: {
disable_ssh_tunneling: true,
supports_file_upload: false,
},
}}
/>,
);
expect(screen.queryByRole('switch')).not.toBeInTheDocument();
});
test('Checks the switch based on db.parameters.ssh', () => {
const dbWithSSHTunnelEnabled = {
...defaultDb,
parameters: { ssh: true },
} as DatabaseObject;
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={dbWithSSHTunnelEnabled}
dbModel={mockDbModel}
/>,
);
expect(screen.getByRole('switch')).toHaveTextContent('ON');
});
test('Calls onParametersChange with true if SSH Tunnel info exists', () => {
const dbWithSSHTunnelInfo = {
...defaultDb,
parameters: { ssh: undefined },
ssh_tunnel: { host: 'example.com' },
} as DatabaseObject;
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={dbWithSSHTunnelInfo}
dbModel={mockDbModel}
/>,
);
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
});
});
test('Displays tooltip text on hover over the InfoTooltip', async () => {
const tooltipText = 'SSH Tunnel configuration parameters';
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={defaultDb}
dbModel={mockDbModel}
/>,
);
const infoTooltipTrigger = screen.getByRole('img', {
name: 'info-solid_small',
});
expect(infoTooltipTrigger).toBeInTheDocument();
userEvent.hover(infoTooltipTrigger);
const tooltip = await screen.findByText(tooltipText);
expect(tooltip).toBeInTheDocument();
});

View File

@ -16,35 +16,73 @@
* specific language governing permissions and limitations
* under the License.
*/
import React from 'react';
import { t, SupersetTheme, SwitchProps } from '@superset-ui/core';
import React, { useEffect, useState } from 'react';
import {
t,
SupersetTheme,
isFeatureEnabled,
FeatureFlag,
} from '@superset-ui/core';
import { AntdSwitch } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip';
import { isEmpty } from 'lodash';
import { ActionType } from '.';
import { infoTooltip, toggleStyle } from './styles';
import { SwitchProps } from '../types';
const SSHTunnelSwitch = ({
isEditMode,
dbFetched,
useSSHTunneling,
setUseSSHTunneling,
setDB,
isSSHTunneling,
}: SwitchProps) =>
isSSHTunneling ? (
clearValidationErrors,
changeMethods,
db,
dbModel,
}: SwitchProps) => {
const [isChecked, setChecked] = useState(false);
const sshTunnelEnabled = isFeatureEnabled(FeatureFlag.SshTunneling);
const disableSSHTunnelingForEngine =
dbModel?.engine_information?.disable_ssh_tunneling || false;
const isSSHTunnelEnabled = sshTunnelEnabled && !disableSSHTunnelingForEngine;
const handleOnChange = (changed: boolean) => {
setChecked(changed);
changeMethods.onParametersChange({
target: {
type: 'toggle',
name: 'ssh',
checked: true,
value: changed,
},
});
clearValidationErrors();
};
useEffect(() => {
if (isSSHTunnelEnabled && db?.parameters?.ssh !== undefined) {
setChecked(db.parameters.ssh);
}
}, [db?.parameters?.ssh, isSSHTunnelEnabled]);
useEffect(() => {
if (
isSSHTunnelEnabled &&
db?.parameters?.ssh === undefined &&
!isEmpty(db?.ssh_tunnel)
) {
// reflecting the state of the ssh tunnel on first load
changeMethods.onParametersChange({
target: {
type: 'toggle',
name: 'ssh',
checked: true,
value: true,
},
});
}
}, [changeMethods, db?.parameters?.ssh, db?.ssh_tunnel, isSSHTunnelEnabled]);
return isSSHTunnelEnabled ? (
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
<AntdSwitch
disabled={isEditMode && !isEmpty(dbFetched?.ssh_tunnel)}
checked={useSSHTunneling}
onChange={changed => {
setUseSSHTunneling(changed);
if (!changed) {
setDB({
type: ActionType.RemoveSSHTunnelConfig,
});
}
}}
checked={isChecked}
onChange={handleOnChange}
data-test="ssh-tunnel-switch"
/>
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
@ -55,4 +93,6 @@ const SSHTunnelSwitch = ({
/>
</div>
) : null;
};
export default SSHTunnelSwitch;

View File

@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
// TODO: These tests should be made atomic in separate files
import React from 'react';
import fetchMock from 'fetch-mock';
import userEvent from '@testing-library/user-event';
@ -1227,9 +1230,9 @@ describe('DatabaseModal', () => {
const SSHTunnelServerPortInput = screen.getByTestId(
'ssh-tunnel-server_port-input',
);
expect(SSHTunnelServerPortInput).toHaveValue('');
expect(SSHTunnelServerPortInput).toHaveValue(null);
userEvent.type(SSHTunnelServerPortInput, '22');
expect(SSHTunnelServerPortInput).toHaveValue('22');
expect(SSHTunnelServerPortInput).toHaveValue(22);
const SSHTunnelUsernameInput = screen.getByTestId(
'ssh-tunnel-username-input',
);
@ -1263,9 +1266,9 @@ describe('DatabaseModal', () => {
const SSHTunnelServerPortInput = screen.getByTestId(
'ssh-tunnel-server_port-input',
);
expect(SSHTunnelServerPortInput).toHaveValue('');
expect(SSHTunnelServerPortInput).toHaveValue(null);
userEvent.type(SSHTunnelServerPortInput, '22');
expect(SSHTunnelServerPortInput).toHaveValue('22');
expect(SSHTunnelServerPortInput).toHaveValue(22);
const SSHTunnelUsernameInput = screen.getByTestId(
'ssh-tunnel-username-input',
);

View File

@ -20,8 +20,6 @@ import {
t,
styled,
SupersetTheme,
FeatureFlag,
isFeatureEnabled,
getExtensionsRegistry,
} from '@superset-ui/core';
import React, {
@ -31,6 +29,7 @@ import React, {
useState,
useReducer,
Reducer,
useCallback,
} from 'react';
import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
@ -65,6 +64,7 @@ import {
CatalogObject,
Engines,
ExtraJson,
CustomTextType,
} from '../types';
import ExtraOptions from './ExtraOptions';
import SqlAlchemyForm from './SqlAlchemyForm';
@ -208,8 +208,8 @@ export type DBReducerActionType =
| {
type:
| ActionType.Reset
| ActionType.AddTableCatalogSheet
| ActionType.RemoveSSHTunnelConfig;
| ActionType.RemoveSSHTunnelConfig
| ActionType.AddTableCatalogSheet;
}
| {
type: ActionType.RemoveTableCatalogSheet;
@ -595,7 +595,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const SSHTunnelSwitchComponent =
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean>(false);
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean | undefined>(
undefined,
);
let dbConfigExtraExtension = extensionsRegistry.get(
'databaseconnection.extraOption',
@ -618,14 +620,6 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const dbImages = getDatabaseImages();
const connectionAlert = getConnectionAlert();
const isEditMode = !!databaseId;
const disableSSHTunnelingForEngine = (
availableDbs?.databases?.find(
(DB: DatabaseObject) =>
DB.backend === db?.engine || DB.engine === db?.engine,
) as DatabaseObject
)?.engine_information?.disable_ssh_tunneling;
const isSSHTunneling =
isFeatureEnabled(FeatureFlag.SshTunneling) && !disableSSHTunnelingForEngine;
const hasAlert =
connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]);
const useSqlAlchemyForm =
@ -659,7 +653,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
extra: db?.extra,
masked_encrypted_extra: db?.masked_encrypted_extra || '',
server_cert: db?.server_cert || undefined,
ssh_tunnel: db?.ssh_tunnel || undefined,
ssh_tunnel:
!isEmpty(db?.ssh_tunnel) && useSSHTunneling
? {
...db.ssh_tunnel,
server_port: Number(db.ssh_tunnel!.server_port),
}
: undefined,
};
setTestInProgress(true);
testDatabaseConnection(
@ -687,10 +687,36 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
return false;
};
const onChange = useCallback(
(
type: DBReducerActionType['type'],
payload: CustomTextType | DBReducerPayloadType,
) => {
setDB({ type, payload } as DBReducerActionType);
},
[],
);
const handleClearValidationErrors = useCallback(() => {
setValidationErrors(null);
}, [setValidationErrors]);
const handleParametersChange = useCallback(
({ target }: { target: HTMLInputElement }) => {
onChange(ActionType.ParametersChange, {
type: target.type,
name: target.name,
checked: target.checked,
value: target.value,
});
},
[onChange],
);
const onClose = () => {
setDB({ type: ActionType.Reset });
setHasConnectedDb(false);
setValidationErrors(null); // reset validation errors on close
handleClearValidationErrors(); // reset validation errors on close
clearError();
setEditNewDb(false);
setFileList([]);
@ -705,7 +731,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({});
setConfirmedOverwrite(false);
setUseSSHTunneling(false);
setUseSSHTunneling(undefined);
onHide();
};
@ -729,12 +755,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setImportingErrorMessage(msg);
});
const onChange = (type: any, payload: any) => {
setDB({ type, payload } as DBReducerActionType);
};
const onSave = async () => {
let dbConfigExtraExtensionOnSaveError;
setLoading(true);
dbConfigExtraExtension
?.onSave(extraExtensionComponentState, db)
.then(({ error }: { error: any }) => {
@ -743,6 +768,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
addDangerToast(error);
}
});
if (dbConfigExtraExtensionOnSaveError) {
setLoading(false);
return;
@ -762,17 +788,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
});
}
// only do validation for non ssh tunnel connections
if (!dbToUpdate?.ssh_tunnel) {
// make sure that button spinner animates
setLoading(true);
const errors = await getValidation(dbToUpdate, true);
if ((validationErrors && !isEmpty(validationErrors)) || errors) {
setLoading(false);
return;
}
// end spinner animation
const errors = await getValidation(dbToUpdate, true);
if (!isEmpty(validationErrors) || errors?.length) {
addDangerToast(
t('Connection failed, please check your connection settings.'),
);
setLoading(false);
return;
}
const parameters_schema = isEditMode
@ -829,7 +851,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
});
}
setLoading(true);
// strictly checking for false as an indication that the toggle got unchecked
if (useSSHTunneling === false) {
// remove ssh tunnel
dbToUpdate.ssh_tunnel = null;
}
if (db?.id) {
const result = await updateResource(
db.id as number,
@ -1282,10 +1309,10 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
}, [sshPrivateKeyPasswordNeeded]);
useEffect(() => {
if (db && isSSHTunneling) {
setUseSSHTunneling(!isEmpty(db?.ssh_tunnel));
if (db?.parameters?.ssh !== undefined) {
setUseSSHTunneling(db.parameters.ssh);
}
}, [db, isSSHTunneling]);
}, [db?.parameters?.ssh]);
const onDbImport = async (info: UploadChangeParam) => {
setImportingErrorMessage('');
@ -1550,17 +1577,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const renderSSHTunnelForm = () => (
<SSHTunnelForm
db={db as DatabaseObject}
onSSHTunnelParametersChange={({
target,
}: {
target: HTMLInputElement | HTMLTextAreaElement;
}) =>
onSSHTunnelParametersChange={({ target }) => {
onChange(ActionType.ParametersSSHTunnelChange, {
type: target.type,
name: target.name,
value: target.value,
})
}
});
handleClearValidationErrors();
}}
setSSHTunnelLoginMethod={(method: AuthType) =>
setDB({
type: ActionType.SetSSHTunnelLoginMethod,
@ -1623,14 +1647,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
payload: { indexToDelete: idx },
});
}}
onParametersChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.ParametersChange, {
type: target.type,
name: target.name,
checked: target.checked,
value: target.value,
})
}
onParametersChange={handleParametersChange}
onChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.TextChange, {
name: target.name,
@ -1640,9 +1657,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
getValidation={() => getValidation(db)}
validationErrors={validationErrors}
getPlaceholder={getPlaceholder}
clearValidationErrors={() => setValidationErrors(null)}
clearValidationErrors={handleClearValidationErrors}
/>
{db?.parameters?.ssh && (
{useSSHTunneling && (
<SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer>
)}
</>
@ -1792,13 +1809,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
testInProgress={testInProgress}
>
<SSHTunnelSwitchComponent
isEditMode={isEditMode}
dbFetched={dbFetched}
disableSSHTunnelingForEngine={disableSSHTunnelingForEngine}
useSSHTunneling={useSSHTunneling}
setUseSSHTunneling={setUseSSHTunneling}
setDB={setDB}
isSSHTunneling={isSSHTunneling}
dbModel={dbModel}
db={db as DatabaseObject}
changeMethods={{
onParametersChange: handleParametersChange,
}}
clearValidationErrors={handleClearValidationErrors}
/>
{useSSHTunneling && renderSSHTunnelForm()}
</SqlAlchemyForm>

View File

@ -1,3 +1,7 @@
import { JsonObject } from '@superset-ui/core';
import { InputProps } from 'antd/lib/input';
import { ChangeEvent, EventHandler, FormEvent } from 'react';
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@ -108,7 +112,7 @@ export type DatabaseObject = {
};
// SSH Tunnel information
ssh_tunnel?: SSHTunnelObject;
ssh_tunnel?: SSHTunnelObject | null;
};
export type DatabaseForm = {
@ -195,6 +199,10 @@ export type DatabaseForm = {
};
preferred: boolean;
sqlalchemy_uri_placeholder: string;
engine_information: {
supports_file_upload: boolean;
disable_ssh_tunneling: boolean;
};
};
// the values should align with the database
@ -231,3 +239,73 @@ export interface ExtraJson {
};
version?: string;
}
export type CustomTextType = {
value?: string | boolean | number;
type?: string | null;
name?: string;
checked?: boolean;
};
type CustomHTMLInputElement = Omit<Partial<CustomTextType>, 'value' | 'type'> &
CustomTextType;
type CustomHTMLTextAreaElement = Omit<
Partial<CustomTextType>,
'value' | 'type'
> &
CustomTextType;
export type CustomParametersChangeType<T = CustomTextType> =
| FormEvent<InputProps>
| { target: T };
export type CustomEventHandlerType = EventHandler<
ChangeEvent<CustomHTMLInputElement | CustomHTMLTextAreaElement>
>;
export interface FieldPropTypes {
required: boolean;
hasTooltip?: boolean;
tooltipText?: (value: any) => string;
placeholder?: string;
onParametersChange: (event: CustomParametersChangeType) => void;
onParametersUploadFileChange: (value: any) => string;
changeMethods: {
onParametersChange: (event: CustomParametersChangeType) => void;
} & {
onChange: (value: any) => string;
} & {
onQueryChange: (value: any) => string;
} & { onParametersUploadFileChange: (value: any) => string } & {
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
} & {
onExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: CustomEventHandlerType;
};
validationErrors: JsonObject | null;
getValidation: () => void;
clearValidationErrors: () => void;
db?: DatabaseObject;
dbModel?: DatabaseForm;
field: string;
isEditMode?: boolean;
sslForced?: boolean;
defaultDBName?: string;
editNewDb?: boolean;
}
type ChangeMethodsType = FieldPropTypes['changeMethods'];
// changeMethods compatibility with dynamic forms
type SwitchPropsChangeMethodsType = {
onParametersChange: ChangeMethodsType['onParametersChange'];
};
export type SwitchProps = {
dbModel: DatabaseForm;
db: DatabaseObject;
changeMethods: SwitchPropsChangeMethodsType;
clearValidationErrors: () => void;
};

View File

@ -35,7 +35,8 @@ import Chart, { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import SupersetText from 'src/utils/textUtils';
import { FavoriteStatus, ImportResourceName, DatabaseObject } from './types';
import { DatabaseObject } from 'src/features/databases/types';
import { FavoriteStatus, ImportResourceName } from './types';
interface ListViewResourceState<D extends object = any> {
loading: boolean;
@ -691,7 +692,7 @@ export const getDatabaseDocumentationLinks = () =>
SupersetText.DB_CONNECTION_DOC_LINKS;
export const testDatabaseConnection = (
connection: DatabaseObject,
connection: Partial<DatabaseObject>,
handleErrorMsg: (errorMsg: string) => void,
addSuccessToast: (arg0: string) => void,
) => {
@ -745,7 +746,7 @@ export function useDatabaseValidation() {
const getValidation = useCallback(
(database: Partial<DatabaseObject> | null, onCreate = false) => {
if (database?.parameters?.ssh) {
// when ssh tunnel is enabled we don't want to render any validation errors
// TODO: /validate_parameters/ and related utils should support ssh tunnel
setValidationErrors(null);
return [];
}

View File

@ -19,6 +19,7 @@ from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as _
from marshmallow import ValidationError
from superset import is_feature_enabled
@ -33,6 +34,7 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
)
@ -57,7 +59,11 @@ class CreateDatabaseCommand(BaseCommand):
try:
# Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run()
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
except (
SupersetErrorsException,
SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@ -103,6 +109,7 @@ class CreateDatabaseCommand(BaseCommand):
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
@ -140,6 +147,7 @@ class CreateDatabaseCommand(BaseCommand):
# Check database_name uniqueness
if not DatabaseDAO.validate_uniqueness(database_name):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
exception = DatabaseInvalidError()
exception.extend(exceptions)

View File

@ -23,11 +23,13 @@ from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
SSHTunnelRequiredFieldValidationError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.utils import make_url_safe
from superset.extensions import event_logger
from superset.models.core import Database
@ -35,9 +37,12 @@ logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
_database: Database
def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database"] = database
self._database = database
def run(self) -> Model:
try:
@ -57,16 +62,22 @@ class CreateSSHTunnelCommand(BaseCommand):
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
password: Optional[str] = self._properties.get("password")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
url = make_url_safe(self._database.sqlalchemy_uri)
if not url.port:
raise SSHTunnelDatabasePortError()
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
if not username:
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
if not private_key and not password:
exceptions.append(SSHTunnelRequiredFieldValidationError("password"))
if private_key_password and private_key is None:
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
if exceptions:

View File

@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError):
message = _("SSH Tunnel parameters are invalid.")
class SSHTunnelDatabasePortError(CommandInvalidError):
message = _("A database port is required when connecting via SSH Tunnel.")
class SSHTunnelUpdateFailedError(UpdateFailedError):
message = _("SSH Tunnel could not be updated.")

View File

@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model
from superset.commands.base import BaseCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
SSHTunnelNotFoundError,
SSHTunnelRequiredFieldValidationError,
@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
logger = logging.getLogger(__name__)
@ -39,20 +41,33 @@ class UpdateSSHTunnelCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
def run(self) -> Model:
def run(self) -> Optional[Model]:
self.validate()
try:
if self._model is not None: # So we dont get incompatible types error
tunnel = SSHTunnelDAO.update(self._model, self._properties)
if self._model is None:
return None
# unset password if private key is provided
if self._properties.get("private_key"):
self._properties["password"] = None
# unset private key and password if password is provided
if self._properties.get("password"):
self._properties["private_key"] = None
self._properties["private_key_password"] = None
tunnel = SSHTunnelDAO.update(self._model, self._properties)
return tunnel
except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex
return tunnel
def validate(self) -> None:
# Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model:
raise SSHTunnelNotFoundError()
url = make_url_safe(self._model.database.sqlalchemy_uri)
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
@ -61,3 +76,5 @@ class UpdateSSHTunnelCommand(BaseCommand):
raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
)
if not url.port:
raise SSHTunnelDatabasePortError()

View File

@ -32,7 +32,10 @@ from superset.commands.database.exceptions import (
DatabaseTestConnectionDriverError,
DatabaseTestConnectionUnexpectedError,
)
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelingNotEnabledError,
)
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
@ -61,20 +64,22 @@ def get_log_connection_action(
class TestConnectionDatabaseCommand(BaseCommand):
_model: Optional[Database] = None
_context: dict[str, Any]
_uri: str
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
self._model: Optional[Database] = None
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
self.validate()
ex_str = ""
if (database_name := self._properties.get("database_name")) is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
ssh_tunnel = self._properties.get("ssh_tunnel")
# context for error messages
url = make_url_safe(uri)
context = {
"hostname": url.host,
"password": url.password,
@ -83,6 +88,14 @@ class TestConnectionDatabaseCommand(BaseCommand):
"database": url.database,
}
self._context = context
self._uri = uri
def run(self) -> None: # pylint: disable=too-many-statements
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
serialized_encrypted_extra = self._properties.get(
"masked_encrypted_extra",
"{}",
@ -103,15 +116,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
encrypted_extra=serialized_encrypted_extra,
)
database.set_sqlalchemy_uri(uri)
database.set_sqlalchemy_uri(self._uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
# Generate tunnel if present in the properties
if ssh_tunnel:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
# If there's an existing tunnel for that DB we need to use the stored
# password, private_key and private_key_password instead
# unmask password while allowing for updated values
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
ssh_tunnel = unmask_password_info(
@ -186,7 +196,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
engine=database.db_engine_spec.__name__,
)
# check for custom errors (wrong username, wrong password, etc)
errors = database.db_engine_spec.extract_errors(ex, context)
errors = database.db_engine_spec.extract_errors(ex, self._context)
raise SupersetErrorsException(errors) from ex
except SupersetSecurityException as ex:
event_logger.log_with_context(
@ -221,9 +231,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
),
engine=database.db_engine_spec.__name__,
)
errors = database.db_engine_spec.extract_errors(ex, context)
errors = database.db_engine_spec.extract_errors(ex, self._context)
raise DatabaseTestConnectionUnexpectedError(errors) from ex
def validate(self) -> None:
if (database_name := self._properties.get("database_name")) is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
if self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
if not self._context.get("port"):
raise SSHTunnelDatabasePortError()

View File

@ -18,6 +18,7 @@ import logging
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as _
from marshmallow import ValidationError
from superset import is_feature_enabled
@ -30,8 +31,11 @@ from superset.commands.database.exceptions import (
DatabaseUpdateFailedError,
)
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
SSHTunnelUpdateFailedError,
@ -47,15 +51,21 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
_model: Optional[Database]
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
def run(self) -> Model:
self.validate()
def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
self.validate()
old_database_name = self._model.database_name
# unmask ``encrypted_extra``
@ -70,36 +80,59 @@ class UpdateDatabaseCommand(BaseCommand):
database = DatabaseDAO.update(self._model, self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
if "ssh_tunnel" in self._properties:
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
# We need to remove the existing tunnel
try:
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
UpdateSSHTunnelCommand(
existing_ssh_tunnel_model.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
# So we can show the original message
DeleteSSHTunnelCommand(ssh_tunnel.id).run()
ssh_tunnel = None
except SSHTunnelDeleteFailedError as ex:
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if ssh_tunnel is None:
# We couldn't found an existing tunnel so we need to create one
try:
ssh_tunnel = CreateSSHTunnelCommand(
database, ssh_tunnel_properties
).run()
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
ssh_tunnel_id = ssh_tunnel.id
ssh_tunnel = UpdateSSHTunnelCommand(
ssh_tunnel_id, ssh_tunnel_properties
).run()
except (
SSHTunnelInvalidError,
SSHTunnelUpdateFailedError,
SSHTunnelDatabasePortError,
) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
# adding a new database we always want to force refresh schema list
# TODO Improve this simplistic implementation for catching DB conn fails
try:
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
except Exception as ex:
db.session.rollback()
@ -167,10 +200,6 @@ class UpdateDatabaseCommand(BaseCommand):
def validate(self) -> None:
exceptions: list[ValidationError] = []
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
database_name: Optional[str] = self._properties.get("database_name")
if database_name:
# Check database_name uniqueness

View File

@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
)
@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True,
)
return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex:
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex))
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True,
)
return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex:
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex))
@expose("/<int:pk>", methods=("DELETE",))
@ -918,7 +919,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
try:
TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK")
except SSHTunnelingNotEnabledError as ex:
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex))
@expose("/<int:pk>/related_objects/", methods=("GET",))

View File

@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
)
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response_create = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response_create.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
)
# Cleanup
model = db.session.query(Database).get(response_create.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_delete_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_delete_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test deleting a SSH tunnel via Database update
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
mock_delete_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
database_data_with_ssh_tunnel_null = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": None,
}
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)

View File

@ -19,7 +19,10 @@
import pytest
from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
def test_create_ssh_tunnel_command() -> None:
@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
properties = {
"database_id": database.id,
@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
def test_create_ssh_tunnel_command_no_port() -> None:
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost/db",
)
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)

View File

@ -20,11 +20,14 @@ from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
def session_with_data(request, session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
@pytest.mark.parametrize(
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test update"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)