fix: SSH Tunnel configuration settings (#27186)

This commit is contained in:
Geido 2024-03-11 16:56:54 +01:00 committed by GitHub
parent fde93dcf08
commit 89e89de341
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 871 additions and 271 deletions

View File

@ -44,15 +44,15 @@ interface MenuObjectChildProps {
disable?: boolean; disable?: boolean;
} }
export interface SwitchProps { // loose typing to avoid any circular dependencies
isEditMode: boolean; // refer to SSHTunnelSwitch component for strict typing
dbFetched: any; type SwitchProps = {
disableSSHTunnelingForEngine?: boolean; db: object;
useSSHTunneling: boolean; changeMethods: {
setUseSSHTunneling: React.Dispatch<React.SetStateAction<boolean>>; onParametersChange: (event: any) => void;
setDB: React.Dispatch<any>; };
isSSHTunneling: boolean; clearValidationErrors: () => void;
} };
type ConfigDetailsProps = { type ConfigDetailsProps = {
embeddedId: string; embeddedId: string;

View File

@ -541,8 +541,8 @@ test('defaults to day when CRON is not selected', async () => {
useRedux: true, useRedux: true,
}); });
userEvent.click(screen.getByTestId('schedule-panel')); userEvent.click(screen.getByTestId('schedule-panel'));
const days = screen.getAllByTitle(/day/i, { exact: true }); const day = screen.getByText('day');
expect(days.length).toBe(2); expect(day).toBeInTheDocument();
}); });
// Notification Method Section // Notification Method Section

View File

@ -17,12 +17,11 @@
* under the License. * under the License.
*/ */
import React from 'react'; import React from 'react';
import { isEmpty } from 'lodash';
import { SupersetTheme, t } from '@superset-ui/core'; import { SupersetTheme, t } from '@superset-ui/core';
import { AntdSwitch } from 'src/components'; import { AntdSwitch } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip'; import InfoTooltip from 'src/components/InfoTooltip';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput'; import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import { FieldPropTypes } from '.'; import { FieldPropTypes } from '../../types';
import { toggleStyle, infoTooltip } from '../styles'; import { toggleStyle, infoTooltip } from '../styles';
export const hostField = ({ export const hostField = ({
@ -252,35 +251,3 @@ export const forceSSLField = ({
/> />
</div> </div>
); );
export const SSHTunnelSwitch = ({
isEditMode,
changeMethods,
clearValidationErrors,
db,
}: FieldPropTypes) => (
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
<AntdSwitch
disabled={isEditMode && !isEmpty(db?.ssh_tunnel)}
checked={db?.parameters?.ssh}
onChange={changed => {
changeMethods.onParametersChange({
target: {
type: 'toggle',
name: 'ssh',
checked: true,
value: changed,
},
});
clearValidationErrors();
}}
data-test="ssh-tunnel-switch"
/>
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
<InfoTooltip
tooltip={t('SSH Tunnel configuration parameters')}
placement="right"
viewBox="0 -5 24 24"
/>
</div>
);

View File

@ -22,7 +22,7 @@ import { AntdButton, AntdSelect } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip'; import InfoTooltip from 'src/components/InfoTooltip';
import FormLabel from 'src/components/Form/FormLabel'; import FormLabel from 'src/components/Form/FormLabel';
import Icons from 'src/components/Icons'; import Icons from 'src/components/Icons';
import { FieldPropTypes } from '.'; import { FieldPropTypes } from '../../types';
import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles'; import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles';
enum CredentialInfoOptions { enum CredentialInfoOptions {

View File

@ -21,9 +21,8 @@ import { css, SupersetTheme, t } from '@superset-ui/core';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput'; import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import FormLabel from 'src/components/Form/FormLabel'; import FormLabel from 'src/components/Form/FormLabel';
import Icons from 'src/components/Icons'; import Icons from 'src/components/Icons';
import { FieldPropTypes } from '.';
import { StyledFooterButton, StyledCatalogTable } from '../styles'; import { StyledFooterButton, StyledCatalogTable } from '../styles';
import { CatalogObject } from '../../types'; import { CatalogObject, FieldPropTypes } from '../../types';
export const TableCatalog = ({ export const TableCatalog = ({
required, required,

View File

@ -19,7 +19,7 @@
import React from 'react'; import React from 'react';
import { t } from '@superset-ui/core'; import { t } from '@superset-ui/core';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput'; import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import { FieldPropTypes } from '.'; import { FieldPropTypes } from '../../types';
const FIELD_TEXT_MAP = { const FIELD_TEXT_MAP = {
account: { account: {

View File

@ -17,7 +17,11 @@
* under the License. * under the License.
*/ */
import React, { FormEvent } from 'react'; import React, { FormEvent } from 'react';
import { SupersetTheme, JsonObject } from '@superset-ui/core'; import {
SupersetTheme,
JsonObject,
getExtensionsRegistry,
} from '@superset-ui/core';
import { InputProps } from 'antd/lib/input'; import { InputProps } from 'antd/lib/input';
import { Form } from 'src/components/Form'; import { Form } from 'src/components/Form';
import { import {
@ -31,13 +35,13 @@ import {
portField, portField,
queryField, queryField,
usernameField, usernameField,
SSHTunnelSwitch,
} from './CommonParameters'; } from './CommonParameters';
import { validatedInputField } from './ValidatedInputField'; import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField'; import { EncryptedField } from './EncryptedField';
import { TableCatalog } from './TableCatalog'; import { TableCatalog } from './TableCatalog';
import { formScrollableStyles, validatedFormStyles } from '../styles'; import { formScrollableStyles, validatedFormStyles } from '../styles';
import { DatabaseForm, DatabaseObject } from '../../types'; import { DatabaseForm, DatabaseObject } from '../../types';
import SSHTunnelSwitch from '../SSHTunnelSwitch';
export const FormFieldOrder = [ export const FormFieldOrder = [
'host', 'host',
@ -59,34 +63,10 @@ export const FormFieldOrder = [
'ssh', 'ssh',
]; ];
export interface FieldPropTypes { const extensionsRegistry = getExtensionsRegistry();
required: boolean;
hasTooltip?: boolean; const SSHTunnelSwitchComponent =
tooltipText?: (value: any) => string; extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
placeholder?: string;
onParametersChange: (value: any) => string;
onParametersUploadFileChange: (value: any) => string;
changeMethods: { onParametersChange: (value: any) => string } & {
onChange: (value: any) => string;
} & {
onQueryChange: (value: any) => string;
} & { onParametersUploadFileChange: (value: any) => string } & {
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
} & {
onExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: (value: any) => string;
};
validationErrors: JsonObject | null;
getValidation: () => void;
clearValidationErrors: () => void;
db?: DatabaseObject;
field: string;
isEditMode?: boolean;
sslForced?: boolean;
defaultDBName?: string;
editNewDb?: boolean;
}
const FORM_FIELD_MAP = { const FORM_FIELD_MAP = {
host: hostField, host: hostField,
@ -105,7 +85,7 @@ const FORM_FIELD_MAP = {
warehouse: validatedInputField, warehouse: validatedInputField,
role: validatedInputField, role: validatedInputField,
account: validatedInputField, account: validatedInputField,
ssh: SSHTunnelSwitch, ssh: SSHTunnelSwitchComponent,
}; };
interface DatabaseConnectionFormProps { interface DatabaseConnectionFormProps {
@ -138,7 +118,7 @@ interface DatabaseConnectionFormProps {
} }
const DatabaseConnectionForm = ({ const DatabaseConnectionForm = ({
dbModel: { parameters }, dbModel,
db, db,
editNewDb, editNewDb,
getPlaceholder, getPlaceholder,
@ -154,47 +134,51 @@ const DatabaseConnectionForm = ({
sslForced, sslForced,
validationErrors, validationErrors,
clearValidationErrors, clearValidationErrors,
}: DatabaseConnectionFormProps) => ( }: DatabaseConnectionFormProps) => {
<Form> const parameters = dbModel?.parameters;
<div
// @ts-ignore return (
css={(theme: SupersetTheme) => [ <Form>
formScrollableStyles, <div
validatedFormStyles(theme), // @ts-ignore
]} css={(theme: SupersetTheme) => [
> formScrollableStyles,
{parameters && validatedFormStyles(theme),
FormFieldOrder.filter( ]}
(key: string) => >
Object.keys(parameters.properties).includes(key) || {parameters &&
key === 'database_name', FormFieldOrder.filter(
).map(field => (key: string) =>
FORM_FIELD_MAP[field]({ Object.keys(parameters.properties).includes(key) ||
required: parameters.required?.includes(field), key === 'database_name',
changeMethods: { ).map(field =>
onParametersChange, FORM_FIELD_MAP[field]({
onChange, required: parameters.required?.includes(field),
onQueryChange, changeMethods: {
onParametersUploadFileChange, onParametersChange,
onAddTableCatalog, onChange,
onRemoveTableCatalog, onQueryChange,
onExtraInputChange, onParametersUploadFileChange,
}, onAddTableCatalog,
validationErrors, onRemoveTableCatalog,
getValidation, onExtraInputChange,
clearValidationErrors, },
db, validationErrors,
key: field, getValidation,
field, clearValidationErrors,
isEditMode, db,
sslForced, key: field,
editNewDb, field,
placeholder: getPlaceholder ? getPlaceholder(field) : undefined, isEditMode,
}), sslForced,
)} editNewDb,
</div> placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
</Form> }),
); )}
</div>
</Form>
);
};
export const FormFieldMap = FORM_FIELD_MAP; export const FormFieldMap = FORM_FIELD_MAP;
export default DatabaseConnectionForm; export default DatabaseConnectionForm;

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
import React, { EventHandler, ChangeEvent, useState } from 'react'; import React, { useState } from 'react';
import { t, styled } from '@superset-ui/core'; import { t, styled } from '@superset-ui/core';
import { AntdForm, Col, Row } from 'src/components'; import { AntdForm, Col, Row } from 'src/components';
import { Form, FormLabel } from 'src/components/Form'; import { Form, FormLabel } from 'src/components/Form';
@ -24,7 +24,7 @@ import { Radio } from 'src/components/Radio';
import { Input, TextArea } from 'src/components/Input'; import { Input, TextArea } from 'src/components/Input';
import { Input as AntdInput, Tooltip } from 'antd'; import { Input as AntdInput, Tooltip } from 'antd';
import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons'; import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
import { DatabaseObject } from '../types'; import { DatabaseObject, FieldPropTypes } from '../types';
import { AuthType } from '.'; import { AuthType } from '.';
const StyledDiv = styled.div` const StyledDiv = styled.div`
@ -54,9 +54,7 @@ const SSHTunnelForm = ({
setSSHTunnelLoginMethod, setSSHTunnelLoginMethod,
}: { }: {
db: DatabaseObject | null; db: DatabaseObject | null;
onSSHTunnelParametersChange: EventHandler< onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange'];
ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
>;
setSSHTunnelLoginMethod: (method: AuthType) => void; setSSHTunnelLoginMethod: (method: AuthType) => void;
}) => { }) => {
const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password); const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password);
@ -86,9 +84,9 @@ const SSHTunnelForm = ({
</FormLabel> </FormLabel>
<Input <Input
name="server_port" name="server_port"
type="text"
placeholder={t('22')} placeholder={t('22')}
value={db?.ssh_tunnel?.server_port || ''} type="number"
value={db?.ssh_tunnel?.server_port}
onChange={onSSHTunnelParametersChange} onChange={onSSHTunnelParametersChange}
data-test="ssh-tunnel-server_port-input" data-test="ssh-tunnel-server_port-input"
/> />

View File

@ -0,0 +1,162 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import React from 'react';
import { render, screen } from 'spec/helpers/testing-library';
import userEvent from '@testing-library/user-event';
import SSHTunnelSwitch from './SSHTunnelSwitch';
import { DatabaseForm, DatabaseObject } from '../types';
jest.mock('@superset-ui/core', () => ({
...jest.requireActual('@superset-ui/core'),
isFeatureEnabled: jest.fn().mockReturnValue(true),
}));
jest.mock('src/components', () => ({
AntdSwitch: ({
checked,
onChange,
}: {
checked: boolean;
onChange: (checked: boolean) => void;
}) => (
<button
onClick={() => onChange(!checked)}
aria-checked={checked}
role="switch"
type="button"
>
{checked ? 'ON' : 'OFF'}
</button>
),
}));
const mockChangeMethods = {
onParametersChange: jest.fn(),
};
const mockDbModel = {
engine: 'mysql',
engine_information: {
disable_ssh_tunneling: false,
},
} as DatabaseForm;
const defaultDb = {
parameters: { ssh: false },
ssh_tunnel: {},
engine: 'mysql',
} as DatabaseObject;
afterEach(() => {
jest.clearAllMocks();
});
test('Renders SSH Tunnel switch enabled by default and toggles its state', () => {
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={defaultDb}
dbModel={mockDbModel}
/>,
);
const switchButton = screen.getByRole('switch');
expect(switchButton).toHaveTextContent('OFF');
userEvent.click(switchButton);
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
});
expect(switchButton).toHaveTextContent('ON');
});
test('Does not render if SSH Tunnel is disabled', () => {
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={defaultDb}
dbModel={{
...mockDbModel,
engine_information: {
disable_ssh_tunneling: true,
supports_file_upload: false,
},
}}
/>,
);
expect(screen.queryByRole('switch')).not.toBeInTheDocument();
});
test('Checks the switch based on db.parameters.ssh', () => {
const dbWithSSHTunnelEnabled = {
...defaultDb,
parameters: { ssh: true },
} as DatabaseObject;
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={dbWithSSHTunnelEnabled}
dbModel={mockDbModel}
/>,
);
expect(screen.getByRole('switch')).toHaveTextContent('ON');
});
test('Calls onParametersChange with true if SSH Tunnel info exists', () => {
const dbWithSSHTunnelInfo = {
...defaultDb,
parameters: { ssh: undefined },
ssh_tunnel: { host: 'example.com' },
} as DatabaseObject;
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={dbWithSSHTunnelInfo}
dbModel={mockDbModel}
/>,
);
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
});
});
test('Displays tooltip text on hover over the InfoTooltip', async () => {
const tooltipText = 'SSH Tunnel configuration parameters';
render(
<SSHTunnelSwitch
changeMethods={mockChangeMethods}
clearValidationErrors={jest.fn}
db={defaultDb}
dbModel={mockDbModel}
/>,
);
const infoTooltipTrigger = screen.getByRole('img', {
name: 'info-solid_small',
});
expect(infoTooltipTrigger).toBeInTheDocument();
userEvent.hover(infoTooltipTrigger);
const tooltip = await screen.findByText(tooltipText);
expect(tooltip).toBeInTheDocument();
});

View File

@ -16,35 +16,73 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
import React from 'react'; import React, { useEffect, useState } from 'react';
import { t, SupersetTheme, SwitchProps } from '@superset-ui/core'; import {
t,
SupersetTheme,
isFeatureEnabled,
FeatureFlag,
} from '@superset-ui/core';
import { AntdSwitch } from 'src/components'; import { AntdSwitch } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip'; import InfoTooltip from 'src/components/InfoTooltip';
import { isEmpty } from 'lodash'; import { isEmpty } from 'lodash';
import { ActionType } from '.';
import { infoTooltip, toggleStyle } from './styles'; import { infoTooltip, toggleStyle } from './styles';
import { SwitchProps } from '../types';
const SSHTunnelSwitch = ({ const SSHTunnelSwitch = ({
isEditMode, clearValidationErrors,
dbFetched, changeMethods,
useSSHTunneling, db,
setUseSSHTunneling, dbModel,
setDB, }: SwitchProps) => {
isSSHTunneling, const [isChecked, setChecked] = useState(false);
}: SwitchProps) => const sshTunnelEnabled = isFeatureEnabled(FeatureFlag.SshTunneling);
isSSHTunneling ? ( const disableSSHTunnelingForEngine =
dbModel?.engine_information?.disable_ssh_tunneling || false;
const isSSHTunnelEnabled = sshTunnelEnabled && !disableSSHTunnelingForEngine;
const handleOnChange = (changed: boolean) => {
setChecked(changed);
changeMethods.onParametersChange({
target: {
type: 'toggle',
name: 'ssh',
checked: true,
value: changed,
},
});
clearValidationErrors();
};
useEffect(() => {
if (isSSHTunnelEnabled && db?.parameters?.ssh !== undefined) {
setChecked(db.parameters.ssh);
}
}, [db?.parameters?.ssh, isSSHTunnelEnabled]);
useEffect(() => {
if (
isSSHTunnelEnabled &&
db?.parameters?.ssh === undefined &&
!isEmpty(db?.ssh_tunnel)
) {
// reflecting the state of the ssh tunnel on first load
changeMethods.onParametersChange({
target: {
type: 'toggle',
name: 'ssh',
checked: true,
value: true,
},
});
}
}, [changeMethods, db?.parameters?.ssh, db?.ssh_tunnel, isSSHTunnelEnabled]);
return isSSHTunnelEnabled ? (
<div css={(theme: SupersetTheme) => infoTooltip(theme)}> <div css={(theme: SupersetTheme) => infoTooltip(theme)}>
<AntdSwitch <AntdSwitch
disabled={isEditMode && !isEmpty(dbFetched?.ssh_tunnel)} checked={isChecked}
checked={useSSHTunneling} onChange={handleOnChange}
onChange={changed => {
setUseSSHTunneling(changed);
if (!changed) {
setDB({
type: ActionType.RemoveSSHTunnelConfig,
});
}
}}
data-test="ssh-tunnel-switch" data-test="ssh-tunnel-switch"
/> />
<span css={toggleStyle}>{t('SSH Tunnel')}</span> <span css={toggleStyle}>{t('SSH Tunnel')}</span>
@ -55,4 +93,6 @@ const SSHTunnelSwitch = ({
/> />
</div> </div>
) : null; ) : null;
};
export default SSHTunnelSwitch; export default SSHTunnelSwitch;

View File

@ -16,6 +16,9 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
// TODO: These tests should be made atomic in separate files
import React from 'react'; import React from 'react';
import fetchMock from 'fetch-mock'; import fetchMock from 'fetch-mock';
import userEvent from '@testing-library/user-event'; import userEvent from '@testing-library/user-event';
@ -1227,9 +1230,9 @@ describe('DatabaseModal', () => {
const SSHTunnelServerPortInput = screen.getByTestId( const SSHTunnelServerPortInput = screen.getByTestId(
'ssh-tunnel-server_port-input', 'ssh-tunnel-server_port-input',
); );
expect(SSHTunnelServerPortInput).toHaveValue(''); expect(SSHTunnelServerPortInput).toHaveValue(null);
userEvent.type(SSHTunnelServerPortInput, '22'); userEvent.type(SSHTunnelServerPortInput, '22');
expect(SSHTunnelServerPortInput).toHaveValue('22'); expect(SSHTunnelServerPortInput).toHaveValue(22);
const SSHTunnelUsernameInput = screen.getByTestId( const SSHTunnelUsernameInput = screen.getByTestId(
'ssh-tunnel-username-input', 'ssh-tunnel-username-input',
); );
@ -1263,9 +1266,9 @@ describe('DatabaseModal', () => {
const SSHTunnelServerPortInput = screen.getByTestId( const SSHTunnelServerPortInput = screen.getByTestId(
'ssh-tunnel-server_port-input', 'ssh-tunnel-server_port-input',
); );
expect(SSHTunnelServerPortInput).toHaveValue(''); expect(SSHTunnelServerPortInput).toHaveValue(null);
userEvent.type(SSHTunnelServerPortInput, '22'); userEvent.type(SSHTunnelServerPortInput, '22');
expect(SSHTunnelServerPortInput).toHaveValue('22'); expect(SSHTunnelServerPortInput).toHaveValue(22);
const SSHTunnelUsernameInput = screen.getByTestId( const SSHTunnelUsernameInput = screen.getByTestId(
'ssh-tunnel-username-input', 'ssh-tunnel-username-input',
); );

View File

@ -20,8 +20,6 @@ import {
t, t,
styled, styled,
SupersetTheme, SupersetTheme,
FeatureFlag,
isFeatureEnabled,
getExtensionsRegistry, getExtensionsRegistry,
} from '@superset-ui/core'; } from '@superset-ui/core';
import React, { import React, {
@ -31,6 +29,7 @@ import React, {
useState, useState,
useReducer, useReducer,
Reducer, Reducer,
useCallback,
} from 'react'; } from 'react';
import { useHistory } from 'react-router-dom'; import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers'; import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
@ -65,6 +64,7 @@ import {
CatalogObject, CatalogObject,
Engines, Engines,
ExtraJson, ExtraJson,
CustomTextType,
} from '../types'; } from '../types';
import ExtraOptions from './ExtraOptions'; import ExtraOptions from './ExtraOptions';
import SqlAlchemyForm from './SqlAlchemyForm'; import SqlAlchemyForm from './SqlAlchemyForm';
@ -208,8 +208,8 @@ export type DBReducerActionType =
| { | {
type: type:
| ActionType.Reset | ActionType.Reset
| ActionType.AddTableCatalogSheet | ActionType.RemoveSSHTunnelConfig
| ActionType.RemoveSSHTunnelConfig; | ActionType.AddTableCatalogSheet;
} }
| { | {
type: ActionType.RemoveTableCatalogSheet; type: ActionType.RemoveTableCatalogSheet;
@ -595,7 +595,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const SSHTunnelSwitchComponent = const SSHTunnelSwitchComponent =
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch; extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean>(false); const [useSSHTunneling, setUseSSHTunneling] = useState<boolean | undefined>(
undefined,
);
let dbConfigExtraExtension = extensionsRegistry.get( let dbConfigExtraExtension = extensionsRegistry.get(
'databaseconnection.extraOption', 'databaseconnection.extraOption',
@ -618,14 +620,6 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const dbImages = getDatabaseImages(); const dbImages = getDatabaseImages();
const connectionAlert = getConnectionAlert(); const connectionAlert = getConnectionAlert();
const isEditMode = !!databaseId; const isEditMode = !!databaseId;
const disableSSHTunnelingForEngine = (
availableDbs?.databases?.find(
(DB: DatabaseObject) =>
DB.backend === db?.engine || DB.engine === db?.engine,
) as DatabaseObject
)?.engine_information?.disable_ssh_tunneling;
const isSSHTunneling =
isFeatureEnabled(FeatureFlag.SshTunneling) && !disableSSHTunnelingForEngine;
const hasAlert = const hasAlert =
connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]); connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]);
const useSqlAlchemyForm = const useSqlAlchemyForm =
@ -659,7 +653,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
extra: db?.extra, extra: db?.extra,
masked_encrypted_extra: db?.masked_encrypted_extra || '', masked_encrypted_extra: db?.masked_encrypted_extra || '',
server_cert: db?.server_cert || undefined, server_cert: db?.server_cert || undefined,
ssh_tunnel: db?.ssh_tunnel || undefined, ssh_tunnel:
!isEmpty(db?.ssh_tunnel) && useSSHTunneling
? {
...db.ssh_tunnel,
server_port: Number(db.ssh_tunnel!.server_port),
}
: undefined,
}; };
setTestInProgress(true); setTestInProgress(true);
testDatabaseConnection( testDatabaseConnection(
@ -687,10 +687,36 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
return false; return false;
}; };
const onChange = useCallback(
(
type: DBReducerActionType['type'],
payload: CustomTextType | DBReducerPayloadType,
) => {
setDB({ type, payload } as DBReducerActionType);
},
[],
);
const handleClearValidationErrors = useCallback(() => {
setValidationErrors(null);
}, [setValidationErrors]);
const handleParametersChange = useCallback(
({ target }: { target: HTMLInputElement }) => {
onChange(ActionType.ParametersChange, {
type: target.type,
name: target.name,
checked: target.checked,
value: target.value,
});
},
[onChange],
);
const onClose = () => { const onClose = () => {
setDB({ type: ActionType.Reset }); setDB({ type: ActionType.Reset });
setHasConnectedDb(false); setHasConnectedDb(false);
setValidationErrors(null); // reset validation errors on close handleClearValidationErrors(); // reset validation errors on close
clearError(); clearError();
setEditNewDb(false); setEditNewDb(false);
setFileList([]); setFileList([]);
@ -705,7 +731,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setSSHTunnelPrivateKeys({}); setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({}); setSSHTunnelPrivateKeyPasswords({});
setConfirmedOverwrite(false); setConfirmedOverwrite(false);
setUseSSHTunneling(false); setUseSSHTunneling(undefined);
onHide(); onHide();
}; };
@ -729,12 +755,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setImportingErrorMessage(msg); setImportingErrorMessage(msg);
}); });
const onChange = (type: any, payload: any) => {
setDB({ type, payload } as DBReducerActionType);
};
const onSave = async () => { const onSave = async () => {
let dbConfigExtraExtensionOnSaveError; let dbConfigExtraExtensionOnSaveError;
setLoading(true);
dbConfigExtraExtension dbConfigExtraExtension
?.onSave(extraExtensionComponentState, db) ?.onSave(extraExtensionComponentState, db)
.then(({ error }: { error: any }) => { .then(({ error }: { error: any }) => {
@ -743,6 +768,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
addDangerToast(error); addDangerToast(error);
} }
}); });
if (dbConfigExtraExtensionOnSaveError) { if (dbConfigExtraExtensionOnSaveError) {
setLoading(false); setLoading(false);
return; return;
@ -762,17 +788,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
}); });
} }
// only do validation for non ssh tunnel connections const errors = await getValidation(dbToUpdate, true);
if (!dbToUpdate?.ssh_tunnel) { if (!isEmpty(validationErrors) || errors?.length) {
// make sure that button spinner animates addDangerToast(
setLoading(true); t('Connection failed, please check your connection settings.'),
const errors = await getValidation(dbToUpdate, true); );
if ((validationErrors && !isEmpty(validationErrors)) || errors) {
setLoading(false);
return;
}
// end spinner animation
setLoading(false); setLoading(false);
return;
} }
const parameters_schema = isEditMode const parameters_schema = isEditMode
@ -829,7 +851,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
}); });
} }
setLoading(true); // strictly checking for false as an indication that the toggle got unchecked
if (useSSHTunneling === false) {
// remove ssh tunnel
dbToUpdate.ssh_tunnel = null;
}
if (db?.id) { if (db?.id) {
const result = await updateResource( const result = await updateResource(
db.id as number, db.id as number,
@ -1282,10 +1309,10 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
}, [sshPrivateKeyPasswordNeeded]); }, [sshPrivateKeyPasswordNeeded]);
useEffect(() => { useEffect(() => {
if (db && isSSHTunneling) { if (db?.parameters?.ssh !== undefined) {
setUseSSHTunneling(!isEmpty(db?.ssh_tunnel)); setUseSSHTunneling(db.parameters.ssh);
} }
}, [db, isSSHTunneling]); }, [db?.parameters?.ssh]);
const onDbImport = async (info: UploadChangeParam) => { const onDbImport = async (info: UploadChangeParam) => {
setImportingErrorMessage(''); setImportingErrorMessage('');
@ -1550,17 +1577,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const renderSSHTunnelForm = () => ( const renderSSHTunnelForm = () => (
<SSHTunnelForm <SSHTunnelForm
db={db as DatabaseObject} db={db as DatabaseObject}
onSSHTunnelParametersChange={({ onSSHTunnelParametersChange={({ target }) => {
target,
}: {
target: HTMLInputElement | HTMLTextAreaElement;
}) =>
onChange(ActionType.ParametersSSHTunnelChange, { onChange(ActionType.ParametersSSHTunnelChange, {
type: target.type, type: target.type,
name: target.name, name: target.name,
value: target.value, value: target.value,
}) });
} handleClearValidationErrors();
}}
setSSHTunnelLoginMethod={(method: AuthType) => setSSHTunnelLoginMethod={(method: AuthType) =>
setDB({ setDB({
type: ActionType.SetSSHTunnelLoginMethod, type: ActionType.SetSSHTunnelLoginMethod,
@ -1623,14 +1647,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
payload: { indexToDelete: idx }, payload: { indexToDelete: idx },
}); });
}} }}
onParametersChange={({ target }: { target: HTMLInputElement }) => onParametersChange={handleParametersChange}
onChange(ActionType.ParametersChange, {
type: target.type,
name: target.name,
checked: target.checked,
value: target.value,
})
}
onChange={({ target }: { target: HTMLInputElement }) => onChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.TextChange, { onChange(ActionType.TextChange, {
name: target.name, name: target.name,
@ -1640,9 +1657,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
getValidation={() => getValidation(db)} getValidation={() => getValidation(db)}
validationErrors={validationErrors} validationErrors={validationErrors}
getPlaceholder={getPlaceholder} getPlaceholder={getPlaceholder}
clearValidationErrors={() => setValidationErrors(null)} clearValidationErrors={handleClearValidationErrors}
/> />
{db?.parameters?.ssh && ( {useSSHTunneling && (
<SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer> <SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer>
)} )}
</> </>
@ -1792,13 +1809,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
testInProgress={testInProgress} testInProgress={testInProgress}
> >
<SSHTunnelSwitchComponent <SSHTunnelSwitchComponent
isEditMode={isEditMode} dbModel={dbModel}
dbFetched={dbFetched} db={db as DatabaseObject}
disableSSHTunnelingForEngine={disableSSHTunnelingForEngine} changeMethods={{
useSSHTunneling={useSSHTunneling} onParametersChange: handleParametersChange,
setUseSSHTunneling={setUseSSHTunneling} }}
setDB={setDB} clearValidationErrors={handleClearValidationErrors}
isSSHTunneling={isSSHTunneling}
/> />
{useSSHTunneling && renderSSHTunnelForm()} {useSSHTunneling && renderSSHTunnelForm()}
</SqlAlchemyForm> </SqlAlchemyForm>

View File

@ -1,3 +1,7 @@
import { JsonObject } from '@superset-ui/core';
import { InputProps } from 'antd/lib/input';
import { ChangeEvent, EventHandler, FormEvent } from 'react';
/** /**
* Licensed to the Apache Software Foundation (ASF) under one * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file * or more contributor license agreements. See the NOTICE file
@ -108,7 +112,7 @@ export type DatabaseObject = {
}; };
// SSH Tunnel information // SSH Tunnel information
ssh_tunnel?: SSHTunnelObject; ssh_tunnel?: SSHTunnelObject | null;
}; };
export type DatabaseForm = { export type DatabaseForm = {
@ -195,6 +199,10 @@ export type DatabaseForm = {
}; };
preferred: boolean; preferred: boolean;
sqlalchemy_uri_placeholder: string; sqlalchemy_uri_placeholder: string;
engine_information: {
supports_file_upload: boolean;
disable_ssh_tunneling: boolean;
};
}; };
// the values should align with the database // the values should align with the database
@ -231,3 +239,73 @@ export interface ExtraJson {
}; };
version?: string; version?: string;
} }
export type CustomTextType = {
value?: string | boolean | number;
type?: string | null;
name?: string;
checked?: boolean;
};
type CustomHTMLInputElement = Omit<Partial<CustomTextType>, 'value' | 'type'> &
CustomTextType;
type CustomHTMLTextAreaElement = Omit<
Partial<CustomTextType>,
'value' | 'type'
> &
CustomTextType;
export type CustomParametersChangeType<T = CustomTextType> =
| FormEvent<InputProps>
| { target: T };
export type CustomEventHandlerType = EventHandler<
ChangeEvent<CustomHTMLInputElement | CustomHTMLTextAreaElement>
>;
export interface FieldPropTypes {
required: boolean;
hasTooltip?: boolean;
tooltipText?: (value: any) => string;
placeholder?: string;
onParametersChange: (event: CustomParametersChangeType) => void;
onParametersUploadFileChange: (value: any) => string;
changeMethods: {
onParametersChange: (event: CustomParametersChangeType) => void;
} & {
onChange: (value: any) => string;
} & {
onQueryChange: (value: any) => string;
} & { onParametersUploadFileChange: (value: any) => string } & {
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
} & {
onExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: CustomEventHandlerType;
};
validationErrors: JsonObject | null;
getValidation: () => void;
clearValidationErrors: () => void;
db?: DatabaseObject;
dbModel?: DatabaseForm;
field: string;
isEditMode?: boolean;
sslForced?: boolean;
defaultDBName?: string;
editNewDb?: boolean;
}
type ChangeMethodsType = FieldPropTypes['changeMethods'];
// changeMethods compatibility with dynamic forms
type SwitchPropsChangeMethodsType = {
onParametersChange: ChangeMethodsType['onParametersChange'];
};
export type SwitchProps = {
dbModel: DatabaseForm;
db: DatabaseObject;
changeMethods: SwitchPropsChangeMethodsType;
clearValidationErrors: () => void;
};

View File

@ -35,7 +35,8 @@ import Chart, { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy'; import copyTextToClipboard from 'src/utils/copy';
import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import SupersetText from 'src/utils/textUtils'; import SupersetText from 'src/utils/textUtils';
import { FavoriteStatus, ImportResourceName, DatabaseObject } from './types'; import { DatabaseObject } from 'src/features/databases/types';
import { FavoriteStatus, ImportResourceName } from './types';
interface ListViewResourceState<D extends object = any> { interface ListViewResourceState<D extends object = any> {
loading: boolean; loading: boolean;
@ -691,7 +692,7 @@ export const getDatabaseDocumentationLinks = () =>
SupersetText.DB_CONNECTION_DOC_LINKS; SupersetText.DB_CONNECTION_DOC_LINKS;
export const testDatabaseConnection = ( export const testDatabaseConnection = (
connection: DatabaseObject, connection: Partial<DatabaseObject>,
handleErrorMsg: (errorMsg: string) => void, handleErrorMsg: (errorMsg: string) => void,
addSuccessToast: (arg0: string) => void, addSuccessToast: (arg0: string) => void,
) => { ) => {
@ -745,7 +746,7 @@ export function useDatabaseValidation() {
const getValidation = useCallback( const getValidation = useCallback(
(database: Partial<DatabaseObject> | null, onCreate = false) => { (database: Partial<DatabaseObject> | null, onCreate = false) => {
if (database?.parameters?.ssh) { if (database?.parameters?.ssh) {
// when ssh tunnel is enabled we don't want to render any validation errors // TODO: /validate_parameters/ and related utils should support ssh tunnel
setValidationErrors(null); setValidationErrors(null);
return []; return [];
} }

View File

@ -19,6 +19,7 @@ from typing import Any, Optional
from flask import current_app from flask import current_app
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as _
from marshmallow import ValidationError from marshmallow import ValidationError
from superset import is_feature_enabled from superset import is_feature_enabled
@ -33,6 +34,7 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError, SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
SSHTunnelInvalidError, SSHTunnelInvalidError,
) )
@ -57,7 +59,11 @@ class CreateDatabaseCommand(BaseCommand):
try: try:
# Test connection before starting create transaction # Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run() TestConnectionDatabaseCommand(self._properties).run()
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex: except (
SupersetErrorsException,
SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
) as ex:
event_logger.log_with_context( event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}", action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@ -103,6 +109,7 @@ class CreateDatabaseCommand(BaseCommand):
SSHTunnelInvalidError, SSHTunnelInvalidError,
SSHTunnelCreateFailedError, SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
) as ex: ) as ex:
db.session.rollback() db.session.rollback()
event_logger.log_with_context( event_logger.log_with_context(
@ -140,6 +147,7 @@ class CreateDatabaseCommand(BaseCommand):
# Check database_name uniqueness # Check database_name uniqueness
if not DatabaseDAO.validate_uniqueness(database_name): if not DatabaseDAO.validate_uniqueness(database_name):
exceptions.append(DatabaseExistsValidationError()) exceptions.append(DatabaseExistsValidationError())
if exceptions: if exceptions:
exception = DatabaseInvalidError() exception = DatabaseInvalidError()
exception.extend(exceptions) exception.extend(exceptions)

View File

@ -23,11 +23,13 @@ from marshmallow import ValidationError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError, SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelInvalidError, SSHTunnelInvalidError,
SSHTunnelRequiredFieldValidationError, SSHTunnelRequiredFieldValidationError,
) )
from superset.daos.database import SSHTunnelDAO from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.utils import make_url_safe
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.models.core import Database from superset.models.core import Database
@ -35,9 +37,12 @@ logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand): class CreateSSHTunnelCommand(BaseCommand):
_database: Database
def __init__(self, database: Database, data: dict[str, Any]): def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
self._properties["database"] = database self._properties["database"] = database
self._database = database
def run(self) -> Model: def run(self) -> Model:
try: try:
@ -57,16 +62,22 @@ class CreateSSHTunnelCommand(BaseCommand):
server_address: Optional[str] = self._properties.get("server_address") server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port") server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username") username: Optional[str] = self._properties.get("username")
password: Optional[str] = self._properties.get("password")
private_key: Optional[str] = self._properties.get("private_key") private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get( private_key_password: Optional[str] = self._properties.get(
"private_key_password" "private_key_password"
) )
url = make_url_safe(self._database.sqlalchemy_uri)
if not url.port:
raise SSHTunnelDatabasePortError()
if not server_address: if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port: if not server_port:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
if not username: if not username:
exceptions.append(SSHTunnelRequiredFieldValidationError("username")) exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
if not private_key and not password:
exceptions.append(SSHTunnelRequiredFieldValidationError("password"))
if private_key_password and private_key is None: if private_key_password and private_key is None:
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
if exceptions: if exceptions:

View File

@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError):
message = _("SSH Tunnel parameters are invalid.") message = _("SSH Tunnel parameters are invalid.")
class SSHTunnelDatabasePortError(CommandInvalidError):
message = _("A database port is required when connecting via SSH Tunnel.")
class SSHTunnelUpdateFailedError(UpdateFailedError): class SSHTunnelUpdateFailedError(UpdateFailedError):
message = _("SSH Tunnel could not be updated.") message = _("SSH Tunnel could not be updated.")

View File

@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError, SSHTunnelInvalidError,
SSHTunnelNotFoundError, SSHTunnelNotFoundError,
SSHTunnelRequiredFieldValidationError, SSHTunnelRequiredFieldValidationError,
@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
from superset.daos.database import SSHTunnelDAO from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,20 +41,33 @@ class UpdateSSHTunnelCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[SSHTunnel] = None self._model: Optional[SSHTunnel] = None
def run(self) -> Model: def run(self) -> Optional[Model]:
self.validate() self.validate()
try: try:
if self._model is not None: # So we dont get incompatible types error if self._model is None:
tunnel = SSHTunnelDAO.update(self._model, self._properties) return None
# unset password if private key is provided
if self._properties.get("private_key"):
self._properties["password"] = None
# unset private key and password if password is provided
if self._properties.get("password"):
self._properties["private_key"] = None
self._properties["private_key_password"] = None
tunnel = SSHTunnelDAO.update(self._model, self._properties)
return tunnel
except DAOUpdateFailedError as ex: except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex raise SSHTunnelUpdateFailedError() from ex
return tunnel
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id) self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model: if not self._model:
raise SSHTunnelNotFoundError() raise SSHTunnelNotFoundError()
url = make_url_safe(self._model.database.sqlalchemy_uri)
private_key: Optional[str] = self._properties.get("private_key") private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get( private_key_password: Optional[str] = self._properties.get(
"private_key_password" "private_key_password"
@ -61,3 +76,5 @@ class UpdateSSHTunnelCommand(BaseCommand):
raise SSHTunnelInvalidError( raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")] exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
) )
if not url.port:
raise SSHTunnelDatabasePortError()

View File

@ -32,7 +32,10 @@ from superset.commands.database.exceptions import (
DatabaseTestConnectionDriverError, DatabaseTestConnectionDriverError,
DatabaseTestConnectionUnexpectedError, DatabaseTestConnectionUnexpectedError,
) )
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelingNotEnabledError,
)
from superset.daos.database import DatabaseDAO, SSHTunnelDAO from superset.daos.database import DatabaseDAO, SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
@ -61,20 +64,22 @@ def get_log_connection_action(
class TestConnectionDatabaseCommand(BaseCommand): class TestConnectionDatabaseCommand(BaseCommand):
_model: Optional[Database] = None
_context: dict[str, Any]
_uri: str
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Database] = None
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches if (database_name := self._properties.get("database_name")) is not None:
self.validate() self._model = DatabaseDAO.get_database_by_name(database_name)
ex_str = ""
uri = self._properties.get("sqlalchemy_uri", "") uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri(): if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted uri = self._model.sqlalchemy_uri_decrypted
ssh_tunnel = self._properties.get("ssh_tunnel")
# context for error messages
url = make_url_safe(uri) url = make_url_safe(uri)
context = { context = {
"hostname": url.host, "hostname": url.host,
"password": url.password, "password": url.password,
@ -83,6 +88,14 @@ class TestConnectionDatabaseCommand(BaseCommand):
"database": url.database, "database": url.database,
} }
self._context = context
self._uri = uri
def run(self) -> None: # pylint: disable=too-many-statements
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
serialized_encrypted_extra = self._properties.get( serialized_encrypted_extra = self._properties.get(
"masked_encrypted_extra", "masked_encrypted_extra",
"{}", "{}",
@ -103,15 +116,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
encrypted_extra=serialized_encrypted_extra, encrypted_extra=serialized_encrypted_extra,
) )
database.set_sqlalchemy_uri(uri) database.set_sqlalchemy_uri(self._uri)
database.db_engine_spec.mutate_db_for_connection_test(database) database.db_engine_spec.mutate_db_for_connection_test(database)
# Generate tunnel if present in the properties # Generate tunnel if present in the properties
if ssh_tunnel: if ssh_tunnel:
if not is_feature_enabled("SSH_TUNNELING"): # unmask password while allowing for updated values
raise SSHTunnelingNotEnabledError()
# If there's an existing tunnel for that DB we need to use the stored
# password, private_key and private_key_password instead
if ssh_tunnel_id := ssh_tunnel.pop("id", None): if ssh_tunnel_id := ssh_tunnel.pop("id", None):
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id): if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
ssh_tunnel = unmask_password_info( ssh_tunnel = unmask_password_info(
@ -186,7 +196,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
engine=database.db_engine_spec.__name__, engine=database.db_engine_spec.__name__,
) )
# check for custom errors (wrong username, wrong password, etc) # check for custom errors (wrong username, wrong password, etc)
errors = database.db_engine_spec.extract_errors(ex, context) errors = database.db_engine_spec.extract_errors(ex, self._context)
raise SupersetErrorsException(errors) from ex raise SupersetErrorsException(errors) from ex
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
event_logger.log_with_context( event_logger.log_with_context(
@ -221,9 +231,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
), ),
engine=database.db_engine_spec.__name__, engine=database.db_engine_spec.__name__,
) )
errors = database.db_engine_spec.extract_errors(ex, context) errors = database.db_engine_spec.extract_errors(ex, self._context)
raise DatabaseTestConnectionUnexpectedError(errors) from ex raise DatabaseTestConnectionUnexpectedError(errors) from ex
def validate(self) -> None: def validate(self) -> None:
if (database_name := self._properties.get("database_name")) is not None: if self._properties.get("ssh_tunnel"):
self._model = DatabaseDAO.get_database_by_name(database_name) if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
if not self._context.get("port"):
raise SSHTunnelDatabasePortError()

View File

@ -18,6 +18,7 @@ import logging
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as _
from marshmallow import ValidationError from marshmallow import ValidationError
from superset import is_feature_enabled from superset import is_feature_enabled
@ -30,8 +31,11 @@ from superset.commands.database.exceptions import (
DatabaseUpdateFailedError, DatabaseUpdateFailedError,
) )
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError, SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
SSHTunnelInvalidError, SSHTunnelInvalidError,
SSHTunnelUpdateFailedError, SSHTunnelUpdateFailedError,
@ -47,15 +51,21 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand): class UpdateDatabaseCommand(BaseCommand):
_model: Optional[Database]
def __init__(self, model_id: int, data: dict[str, Any]): def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
self._model_id = model_id self._model_id = model_id
self._model: Optional[Database] = None self._model: Optional[Database] = None
def run(self) -> Model: def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches
self.validate() self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model: if not self._model:
raise DatabaseNotFoundError() raise DatabaseNotFoundError()
self.validate()
old_database_name = self._model.database_name old_database_name = self._model.database_name
# unmask ``encrypted_extra`` # unmask ``encrypted_extra``
@ -70,36 +80,59 @@ class UpdateDatabaseCommand(BaseCommand):
database = DatabaseDAO.update(self._model, self._properties, commit=False) database = DatabaseDAO.update(self._model, self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri) database.set_sqlalchemy_uri(database.sqlalchemy_uri)
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
if "ssh_tunnel" in self._properties:
if not is_feature_enabled("SSH_TUNNELING"): if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback() db.session.rollback()
raise SSHTunnelingNotEnabledError() raise SSHTunnelingNotEnabledError()
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
if existing_ssh_tunnel_model is None: if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
# We couldn't found an existing tunnel so we need to create one # We need to remove the existing tunnel
try: try:
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run() DeleteSSHTunnelCommand(ssh_tunnel.id).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: ssh_tunnel = None
# So we can show the original message except SSHTunnelDeleteFailedError as ex:
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
UpdateSSHTunnelCommand(
existing_ssh_tunnel_model.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
# So we can show the original message
raise ex raise ex
except Exception as ex: except Exception as ex:
raise DatabaseUpdateFailedError() from ex raise DatabaseUpdateFailedError() from ex
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if ssh_tunnel is None:
# We couldn't found an existing tunnel so we need to create one
try:
ssh_tunnel = CreateSSHTunnelCommand(
database, ssh_tunnel_properties
).run()
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
ssh_tunnel_id = ssh_tunnel.id
ssh_tunnel = UpdateSSHTunnelCommand(
ssh_tunnel_id, ssh_tunnel_properties
).run()
except (
SSHTunnelInvalidError,
SSHTunnelUpdateFailedError,
SSHTunnelDatabasePortError,
) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
# adding a new database we always want to force refresh schema list # adding a new database we always want to force refresh schema list
# TODO Improve this simplistic implementation for catching DB conn fails # TODO Improve this simplistic implementation for catching DB conn fails
try: try:
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
except Exception as ex: except Exception as ex:
db.session.rollback() db.session.rollback()
@ -167,10 +200,6 @@ class UpdateDatabaseCommand(BaseCommand):
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
database_name: Optional[str] = self._properties.get("database_name") database_name: Optional[str] = self._properties.get("database_name")
if database_name: if database_name:
# Check database_name uniqueness # Check database_name uniqueness

View File

@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelDeleteFailedError, SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
) )
@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True, exc_info=True,
) )
return self.response_422(message=str(ex)) return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex: except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex)) return self.response_400(message=str(ex))
except SupersetException as ex: except SupersetException as ex:
return self.response(ex.status, message=ex.message) return self.response(ex.status, message=ex.message)
@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True, exc_info=True,
) )
return self.response_422(message=str(ex)) return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex: except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex)) return self.response_400(message=str(ex))
@expose("/<int:pk>", methods=("DELETE",)) @expose("/<int:pk>", methods=("DELETE",))
@ -918,7 +919,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
try: try:
TestConnectionDatabaseCommand(item).run() TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except SSHTunnelingNotEnabledError as ex: except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex)) return self.response_400(message=str(ex))
@expose("/<int:pk>/related_objects/", methods=("GET",)) @expose("/<int:pk>/related_objects/", methods=("GET",))

View File

@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func from sqlalchemy.sql import func
from superset import db, security_manager from superset import db, security_manager
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model) db.session.delete(model)
db.session.commit() db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
)
@mock.patch( @mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
) )
@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model) db.session.delete(model)
db.session.commit() db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response_create = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response_create.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
)
# Cleanup
model = db.session.query(Database).get(response_create.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_delete_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_delete_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test deleting a SSH tunnel via Database update
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
mock_delete_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
database_data_with_ssh_tunnel_null = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": None,
}
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch( @mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
) )

View File

@ -19,7 +19,10 @@
import pytest import pytest
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
def test_create_ssh_tunnel_command() -> None: def test_create_ssh_tunnel_command() -> None:
@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
properties = { properties = {
"database_id": database.id, "database_id": database.id,
@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
# If we are trying to create a tunnel with a private_key_password # If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory # then a private_key is mandatory
@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo: with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run() command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
def test_create_ssh_tunnel_command_no_port() -> None:
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost/db",
)
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)

View File

@ -20,11 +20,14 @@ from collections.abc import Iterator
import pytest import pytest
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
@pytest.fixture @pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]: def session_with_data(request, session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database from superset.models.core import Database
@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind() engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
sqla_table = SqlaTable( sqla_table = SqlaTable(
table_name="my_sqla_table", table_name="my_sqla_table",
columns=[], columns=[],
@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo: with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run() command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
@pytest.mark.parametrize(
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test update"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)