mirror of https://github.com/apache/superset.git
fix: SSH Tunnel configuration settings (#27186)
This commit is contained in:
parent
fde93dcf08
commit
89e89de341
|
@ -44,15 +44,15 @@ interface MenuObjectChildProps {
|
|||
disable?: boolean;
|
||||
}
|
||||
|
||||
export interface SwitchProps {
|
||||
isEditMode: boolean;
|
||||
dbFetched: any;
|
||||
disableSSHTunnelingForEngine?: boolean;
|
||||
useSSHTunneling: boolean;
|
||||
setUseSSHTunneling: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setDB: React.Dispatch<any>;
|
||||
isSSHTunneling: boolean;
|
||||
}
|
||||
// loose typing to avoid any circular dependencies
|
||||
// refer to SSHTunnelSwitch component for strict typing
|
||||
type SwitchProps = {
|
||||
db: object;
|
||||
changeMethods: {
|
||||
onParametersChange: (event: any) => void;
|
||||
};
|
||||
clearValidationErrors: () => void;
|
||||
};
|
||||
|
||||
type ConfigDetailsProps = {
|
||||
embeddedId: string;
|
||||
|
|
|
@ -541,8 +541,8 @@ test('defaults to day when CRON is not selected', async () => {
|
|||
useRedux: true,
|
||||
});
|
||||
userEvent.click(screen.getByTestId('schedule-panel'));
|
||||
const days = screen.getAllByTitle(/day/i, { exact: true });
|
||||
expect(days.length).toBe(2);
|
||||
const day = screen.getByText('day');
|
||||
expect(day).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Notification Method Section
|
||||
|
|
|
@ -17,12 +17,11 @@
|
|||
* under the License.
|
||||
*/
|
||||
import React from 'react';
|
||||
import { isEmpty } from 'lodash';
|
||||
import { SupersetTheme, t } from '@superset-ui/core';
|
||||
import { AntdSwitch } from 'src/components';
|
||||
import InfoTooltip from 'src/components/InfoTooltip';
|
||||
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
||||
import { FieldPropTypes } from '.';
|
||||
import { FieldPropTypes } from '../../types';
|
||||
import { toggleStyle, infoTooltip } from '../styles';
|
||||
|
||||
export const hostField = ({
|
||||
|
@ -252,35 +251,3 @@ export const forceSSLField = ({
|
|||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
export const SSHTunnelSwitch = ({
|
||||
isEditMode,
|
||||
changeMethods,
|
||||
clearValidationErrors,
|
||||
db,
|
||||
}: FieldPropTypes) => (
|
||||
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
|
||||
<AntdSwitch
|
||||
disabled={isEditMode && !isEmpty(db?.ssh_tunnel)}
|
||||
checked={db?.parameters?.ssh}
|
||||
onChange={changed => {
|
||||
changeMethods.onParametersChange({
|
||||
target: {
|
||||
type: 'toggle',
|
||||
name: 'ssh',
|
||||
checked: true,
|
||||
value: changed,
|
||||
},
|
||||
});
|
||||
clearValidationErrors();
|
||||
}}
|
||||
data-test="ssh-tunnel-switch"
|
||||
/>
|
||||
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
|
||||
<InfoTooltip
|
||||
tooltip={t('SSH Tunnel configuration parameters')}
|
||||
placement="right"
|
||||
viewBox="0 -5 24 24"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
|
@ -22,7 +22,7 @@ import { AntdButton, AntdSelect } from 'src/components';
|
|||
import InfoTooltip from 'src/components/InfoTooltip';
|
||||
import FormLabel from 'src/components/Form/FormLabel';
|
||||
import Icons from 'src/components/Icons';
|
||||
import { FieldPropTypes } from '.';
|
||||
import { FieldPropTypes } from '../../types';
|
||||
import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles';
|
||||
|
||||
enum CredentialInfoOptions {
|
||||
|
|
|
@ -21,9 +21,8 @@ import { css, SupersetTheme, t } from '@superset-ui/core';
|
|||
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
||||
import FormLabel from 'src/components/Form/FormLabel';
|
||||
import Icons from 'src/components/Icons';
|
||||
import { FieldPropTypes } from '.';
|
||||
import { StyledFooterButton, StyledCatalogTable } from '../styles';
|
||||
import { CatalogObject } from '../../types';
|
||||
import { CatalogObject, FieldPropTypes } from '../../types';
|
||||
|
||||
export const TableCatalog = ({
|
||||
required,
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
import React from 'react';
|
||||
import { t } from '@superset-ui/core';
|
||||
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
||||
import { FieldPropTypes } from '.';
|
||||
import { FieldPropTypes } from '../../types';
|
||||
|
||||
const FIELD_TEXT_MAP = {
|
||||
account: {
|
||||
|
|
|
@ -17,7 +17,11 @@
|
|||
* under the License.
|
||||
*/
|
||||
import React, { FormEvent } from 'react';
|
||||
import { SupersetTheme, JsonObject } from '@superset-ui/core';
|
||||
import {
|
||||
SupersetTheme,
|
||||
JsonObject,
|
||||
getExtensionsRegistry,
|
||||
} from '@superset-ui/core';
|
||||
import { InputProps } from 'antd/lib/input';
|
||||
import { Form } from 'src/components/Form';
|
||||
import {
|
||||
|
@ -31,13 +35,13 @@ import {
|
|||
portField,
|
||||
queryField,
|
||||
usernameField,
|
||||
SSHTunnelSwitch,
|
||||
} from './CommonParameters';
|
||||
import { validatedInputField } from './ValidatedInputField';
|
||||
import { EncryptedField } from './EncryptedField';
|
||||
import { TableCatalog } from './TableCatalog';
|
||||
import { formScrollableStyles, validatedFormStyles } from '../styles';
|
||||
import { DatabaseForm, DatabaseObject } from '../../types';
|
||||
import SSHTunnelSwitch from '../SSHTunnelSwitch';
|
||||
|
||||
export const FormFieldOrder = [
|
||||
'host',
|
||||
|
@ -59,34 +63,10 @@ export const FormFieldOrder = [
|
|||
'ssh',
|
||||
];
|
||||
|
||||
export interface FieldPropTypes {
|
||||
required: boolean;
|
||||
hasTooltip?: boolean;
|
||||
tooltipText?: (value: any) => string;
|
||||
placeholder?: string;
|
||||
onParametersChange: (value: any) => string;
|
||||
onParametersUploadFileChange: (value: any) => string;
|
||||
changeMethods: { onParametersChange: (value: any) => string } & {
|
||||
onChange: (value: any) => string;
|
||||
} & {
|
||||
onQueryChange: (value: any) => string;
|
||||
} & { onParametersUploadFileChange: (value: any) => string } & {
|
||||
onAddTableCatalog: () => void;
|
||||
onRemoveTableCatalog: (idx: number) => void;
|
||||
} & {
|
||||
onExtraInputChange: (value: any) => void;
|
||||
onSSHTunnelParametersChange: (value: any) => string;
|
||||
};
|
||||
validationErrors: JsonObject | null;
|
||||
getValidation: () => void;
|
||||
clearValidationErrors: () => void;
|
||||
db?: DatabaseObject;
|
||||
field: string;
|
||||
isEditMode?: boolean;
|
||||
sslForced?: boolean;
|
||||
defaultDBName?: string;
|
||||
editNewDb?: boolean;
|
||||
}
|
||||
const extensionsRegistry = getExtensionsRegistry();
|
||||
|
||||
const SSHTunnelSwitchComponent =
|
||||
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
|
||||
|
||||
const FORM_FIELD_MAP = {
|
||||
host: hostField,
|
||||
|
@ -105,7 +85,7 @@ const FORM_FIELD_MAP = {
|
|||
warehouse: validatedInputField,
|
||||
role: validatedInputField,
|
||||
account: validatedInputField,
|
||||
ssh: SSHTunnelSwitch,
|
||||
ssh: SSHTunnelSwitchComponent,
|
||||
};
|
||||
|
||||
interface DatabaseConnectionFormProps {
|
||||
|
@ -138,7 +118,7 @@ interface DatabaseConnectionFormProps {
|
|||
}
|
||||
|
||||
const DatabaseConnectionForm = ({
|
||||
dbModel: { parameters },
|
||||
dbModel,
|
||||
db,
|
||||
editNewDb,
|
||||
getPlaceholder,
|
||||
|
@ -154,47 +134,51 @@ const DatabaseConnectionForm = ({
|
|||
sslForced,
|
||||
validationErrors,
|
||||
clearValidationErrors,
|
||||
}: DatabaseConnectionFormProps) => (
|
||||
<Form>
|
||||
<div
|
||||
// @ts-ignore
|
||||
css={(theme: SupersetTheme) => [
|
||||
formScrollableStyles,
|
||||
validatedFormStyles(theme),
|
||||
]}
|
||||
>
|
||||
{parameters &&
|
||||
FormFieldOrder.filter(
|
||||
(key: string) =>
|
||||
Object.keys(parameters.properties).includes(key) ||
|
||||
key === 'database_name',
|
||||
).map(field =>
|
||||
FORM_FIELD_MAP[field]({
|
||||
required: parameters.required?.includes(field),
|
||||
changeMethods: {
|
||||
onParametersChange,
|
||||
onChange,
|
||||
onQueryChange,
|
||||
onParametersUploadFileChange,
|
||||
onAddTableCatalog,
|
||||
onRemoveTableCatalog,
|
||||
onExtraInputChange,
|
||||
},
|
||||
validationErrors,
|
||||
getValidation,
|
||||
clearValidationErrors,
|
||||
db,
|
||||
key: field,
|
||||
field,
|
||||
isEditMode,
|
||||
sslForced,
|
||||
editNewDb,
|
||||
placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
|
||||
}),
|
||||
)}
|
||||
</div>
|
||||
</Form>
|
||||
);
|
||||
}: DatabaseConnectionFormProps) => {
|
||||
const parameters = dbModel?.parameters;
|
||||
|
||||
return (
|
||||
<Form>
|
||||
<div
|
||||
// @ts-ignore
|
||||
css={(theme: SupersetTheme) => [
|
||||
formScrollableStyles,
|
||||
validatedFormStyles(theme),
|
||||
]}
|
||||
>
|
||||
{parameters &&
|
||||
FormFieldOrder.filter(
|
||||
(key: string) =>
|
||||
Object.keys(parameters.properties).includes(key) ||
|
||||
key === 'database_name',
|
||||
).map(field =>
|
||||
FORM_FIELD_MAP[field]({
|
||||
required: parameters.required?.includes(field),
|
||||
changeMethods: {
|
||||
onParametersChange,
|
||||
onChange,
|
||||
onQueryChange,
|
||||
onParametersUploadFileChange,
|
||||
onAddTableCatalog,
|
||||
onRemoveTableCatalog,
|
||||
onExtraInputChange,
|
||||
},
|
||||
validationErrors,
|
||||
getValidation,
|
||||
clearValidationErrors,
|
||||
db,
|
||||
key: field,
|
||||
field,
|
||||
isEditMode,
|
||||
sslForced,
|
||||
editNewDb,
|
||||
placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
|
||||
}),
|
||||
)}
|
||||
</div>
|
||||
</Form>
|
||||
);
|
||||
};
|
||||
export const FormFieldMap = FORM_FIELD_MAP;
|
||||
|
||||
export default DatabaseConnectionForm;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import React, { EventHandler, ChangeEvent, useState } from 'react';
|
||||
import React, { useState } from 'react';
|
||||
import { t, styled } from '@superset-ui/core';
|
||||
import { AntdForm, Col, Row } from 'src/components';
|
||||
import { Form, FormLabel } from 'src/components/Form';
|
||||
|
@ -24,7 +24,7 @@ import { Radio } from 'src/components/Radio';
|
|||
import { Input, TextArea } from 'src/components/Input';
|
||||
import { Input as AntdInput, Tooltip } from 'antd';
|
||||
import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
||||
import { DatabaseObject } from '../types';
|
||||
import { DatabaseObject, FieldPropTypes } from '../types';
|
||||
import { AuthType } from '.';
|
||||
|
||||
const StyledDiv = styled.div`
|
||||
|
@ -54,9 +54,7 @@ const SSHTunnelForm = ({
|
|||
setSSHTunnelLoginMethod,
|
||||
}: {
|
||||
db: DatabaseObject | null;
|
||||
onSSHTunnelParametersChange: EventHandler<
|
||||
ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
|
||||
>;
|
||||
onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange'];
|
||||
setSSHTunnelLoginMethod: (method: AuthType) => void;
|
||||
}) => {
|
||||
const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password);
|
||||
|
@ -86,9 +84,9 @@ const SSHTunnelForm = ({
|
|||
</FormLabel>
|
||||
<Input
|
||||
name="server_port"
|
||||
type="text"
|
||||
placeholder={t('22')}
|
||||
value={db?.ssh_tunnel?.server_port || ''}
|
||||
type="number"
|
||||
value={db?.ssh_tunnel?.server_port}
|
||||
onChange={onSSHTunnelParametersChange}
|
||||
data-test="ssh-tunnel-server_port-input"
|
||||
/>
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import React from 'react';
|
||||
import { render, screen } from 'spec/helpers/testing-library';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import SSHTunnelSwitch from './SSHTunnelSwitch';
|
||||
import { DatabaseForm, DatabaseObject } from '../types';
|
||||
|
||||
jest.mock('@superset-ui/core', () => ({
|
||||
...jest.requireActual('@superset-ui/core'),
|
||||
isFeatureEnabled: jest.fn().mockReturnValue(true),
|
||||
}));
|
||||
|
||||
jest.mock('src/components', () => ({
|
||||
AntdSwitch: ({
|
||||
checked,
|
||||
onChange,
|
||||
}: {
|
||||
checked: boolean;
|
||||
onChange: (checked: boolean) => void;
|
||||
}) => (
|
||||
<button
|
||||
onClick={() => onChange(!checked)}
|
||||
aria-checked={checked}
|
||||
role="switch"
|
||||
type="button"
|
||||
>
|
||||
{checked ? 'ON' : 'OFF'}
|
||||
</button>
|
||||
),
|
||||
}));
|
||||
|
||||
const mockChangeMethods = {
|
||||
onParametersChange: jest.fn(),
|
||||
};
|
||||
|
||||
const mockDbModel = {
|
||||
engine: 'mysql',
|
||||
engine_information: {
|
||||
disable_ssh_tunneling: false,
|
||||
},
|
||||
} as DatabaseForm;
|
||||
|
||||
const defaultDb = {
|
||||
parameters: { ssh: false },
|
||||
ssh_tunnel: {},
|
||||
engine: 'mysql',
|
||||
} as DatabaseObject;
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test('Renders SSH Tunnel switch enabled by default and toggles its state', () => {
|
||||
render(
|
||||
<SSHTunnelSwitch
|
||||
changeMethods={mockChangeMethods}
|
||||
clearValidationErrors={jest.fn}
|
||||
db={defaultDb}
|
||||
dbModel={mockDbModel}
|
||||
/>,
|
||||
);
|
||||
const switchButton = screen.getByRole('switch');
|
||||
expect(switchButton).toHaveTextContent('OFF');
|
||||
userEvent.click(switchButton);
|
||||
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
|
||||
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
|
||||
});
|
||||
expect(switchButton).toHaveTextContent('ON');
|
||||
});
|
||||
|
||||
test('Does not render if SSH Tunnel is disabled', () => {
|
||||
render(
|
||||
<SSHTunnelSwitch
|
||||
changeMethods={mockChangeMethods}
|
||||
clearValidationErrors={jest.fn}
|
||||
db={defaultDb}
|
||||
dbModel={{
|
||||
...mockDbModel,
|
||||
engine_information: {
|
||||
disable_ssh_tunneling: true,
|
||||
supports_file_upload: false,
|
||||
},
|
||||
}}
|
||||
/>,
|
||||
);
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('Checks the switch based on db.parameters.ssh', () => {
|
||||
const dbWithSSHTunnelEnabled = {
|
||||
...defaultDb,
|
||||
parameters: { ssh: true },
|
||||
} as DatabaseObject;
|
||||
render(
|
||||
<SSHTunnelSwitch
|
||||
changeMethods={mockChangeMethods}
|
||||
clearValidationErrors={jest.fn}
|
||||
db={dbWithSSHTunnelEnabled}
|
||||
dbModel={mockDbModel}
|
||||
/>,
|
||||
);
|
||||
expect(screen.getByRole('switch')).toHaveTextContent('ON');
|
||||
});
|
||||
|
||||
test('Calls onParametersChange with true if SSH Tunnel info exists', () => {
|
||||
const dbWithSSHTunnelInfo = {
|
||||
...defaultDb,
|
||||
parameters: { ssh: undefined },
|
||||
ssh_tunnel: { host: 'example.com' },
|
||||
} as DatabaseObject;
|
||||
render(
|
||||
<SSHTunnelSwitch
|
||||
changeMethods={mockChangeMethods}
|
||||
clearValidationErrors={jest.fn}
|
||||
db={dbWithSSHTunnelInfo}
|
||||
dbModel={mockDbModel}
|
||||
/>,
|
||||
);
|
||||
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
|
||||
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
|
||||
});
|
||||
});
|
||||
|
||||
test('Displays tooltip text on hover over the InfoTooltip', async () => {
|
||||
const tooltipText = 'SSH Tunnel configuration parameters';
|
||||
render(
|
||||
<SSHTunnelSwitch
|
||||
changeMethods={mockChangeMethods}
|
||||
clearValidationErrors={jest.fn}
|
||||
db={defaultDb}
|
||||
dbModel={mockDbModel}
|
||||
/>,
|
||||
);
|
||||
|
||||
const infoTooltipTrigger = screen.getByRole('img', {
|
||||
name: 'info-solid_small',
|
||||
});
|
||||
expect(infoTooltipTrigger).toBeInTheDocument();
|
||||
|
||||
userEvent.hover(infoTooltipTrigger);
|
||||
|
||||
const tooltip = await screen.findByText(tooltipText);
|
||||
|
||||
expect(tooltip).toBeInTheDocument();
|
||||
});
|
|
@ -16,35 +16,73 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import React from 'react';
|
||||
import { t, SupersetTheme, SwitchProps } from '@superset-ui/core';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
t,
|
||||
SupersetTheme,
|
||||
isFeatureEnabled,
|
||||
FeatureFlag,
|
||||
} from '@superset-ui/core';
|
||||
import { AntdSwitch } from 'src/components';
|
||||
import InfoTooltip from 'src/components/InfoTooltip';
|
||||
import { isEmpty } from 'lodash';
|
||||
import { ActionType } from '.';
|
||||
import { infoTooltip, toggleStyle } from './styles';
|
||||
import { SwitchProps } from '../types';
|
||||
|
||||
const SSHTunnelSwitch = ({
|
||||
isEditMode,
|
||||
dbFetched,
|
||||
useSSHTunneling,
|
||||
setUseSSHTunneling,
|
||||
setDB,
|
||||
isSSHTunneling,
|
||||
}: SwitchProps) =>
|
||||
isSSHTunneling ? (
|
||||
clearValidationErrors,
|
||||
changeMethods,
|
||||
db,
|
||||
dbModel,
|
||||
}: SwitchProps) => {
|
||||
const [isChecked, setChecked] = useState(false);
|
||||
const sshTunnelEnabled = isFeatureEnabled(FeatureFlag.SshTunneling);
|
||||
const disableSSHTunnelingForEngine =
|
||||
dbModel?.engine_information?.disable_ssh_tunneling || false;
|
||||
const isSSHTunnelEnabled = sshTunnelEnabled && !disableSSHTunnelingForEngine;
|
||||
|
||||
const handleOnChange = (changed: boolean) => {
|
||||
setChecked(changed);
|
||||
changeMethods.onParametersChange({
|
||||
target: {
|
||||
type: 'toggle',
|
||||
name: 'ssh',
|
||||
checked: true,
|
||||
value: changed,
|
||||
},
|
||||
});
|
||||
clearValidationErrors();
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (isSSHTunnelEnabled && db?.parameters?.ssh !== undefined) {
|
||||
setChecked(db.parameters.ssh);
|
||||
}
|
||||
}, [db?.parameters?.ssh, isSSHTunnelEnabled]);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
isSSHTunnelEnabled &&
|
||||
db?.parameters?.ssh === undefined &&
|
||||
!isEmpty(db?.ssh_tunnel)
|
||||
) {
|
||||
// reflecting the state of the ssh tunnel on first load
|
||||
changeMethods.onParametersChange({
|
||||
target: {
|
||||
type: 'toggle',
|
||||
name: 'ssh',
|
||||
checked: true,
|
||||
value: true,
|
||||
},
|
||||
});
|
||||
}
|
||||
}, [changeMethods, db?.parameters?.ssh, db?.ssh_tunnel, isSSHTunnelEnabled]);
|
||||
|
||||
return isSSHTunnelEnabled ? (
|
||||
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
|
||||
<AntdSwitch
|
||||
disabled={isEditMode && !isEmpty(dbFetched?.ssh_tunnel)}
|
||||
checked={useSSHTunneling}
|
||||
onChange={changed => {
|
||||
setUseSSHTunneling(changed);
|
||||
if (!changed) {
|
||||
setDB({
|
||||
type: ActionType.RemoveSSHTunnelConfig,
|
||||
});
|
||||
}
|
||||
}}
|
||||
checked={isChecked}
|
||||
onChange={handleOnChange}
|
||||
data-test="ssh-tunnel-switch"
|
||||
/>
|
||||
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
|
||||
|
@ -55,4 +93,6 @@ const SSHTunnelSwitch = ({
|
|||
/>
|
||||
</div>
|
||||
) : null;
|
||||
};
|
||||
|
||||
export default SSHTunnelSwitch;
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
// TODO: These tests should be made atomic in separate files
|
||||
|
||||
import React from 'react';
|
||||
import fetchMock from 'fetch-mock';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
|
@ -1227,9 +1230,9 @@ describe('DatabaseModal', () => {
|
|||
const SSHTunnelServerPortInput = screen.getByTestId(
|
||||
'ssh-tunnel-server_port-input',
|
||||
);
|
||||
expect(SSHTunnelServerPortInput).toHaveValue('');
|
||||
expect(SSHTunnelServerPortInput).toHaveValue(null);
|
||||
userEvent.type(SSHTunnelServerPortInput, '22');
|
||||
expect(SSHTunnelServerPortInput).toHaveValue('22');
|
||||
expect(SSHTunnelServerPortInput).toHaveValue(22);
|
||||
const SSHTunnelUsernameInput = screen.getByTestId(
|
||||
'ssh-tunnel-username-input',
|
||||
);
|
||||
|
@ -1263,9 +1266,9 @@ describe('DatabaseModal', () => {
|
|||
const SSHTunnelServerPortInput = screen.getByTestId(
|
||||
'ssh-tunnel-server_port-input',
|
||||
);
|
||||
expect(SSHTunnelServerPortInput).toHaveValue('');
|
||||
expect(SSHTunnelServerPortInput).toHaveValue(null);
|
||||
userEvent.type(SSHTunnelServerPortInput, '22');
|
||||
expect(SSHTunnelServerPortInput).toHaveValue('22');
|
||||
expect(SSHTunnelServerPortInput).toHaveValue(22);
|
||||
const SSHTunnelUsernameInput = screen.getByTestId(
|
||||
'ssh-tunnel-username-input',
|
||||
);
|
||||
|
|
|
@ -20,8 +20,6 @@ import {
|
|||
t,
|
||||
styled,
|
||||
SupersetTheme,
|
||||
FeatureFlag,
|
||||
isFeatureEnabled,
|
||||
getExtensionsRegistry,
|
||||
} from '@superset-ui/core';
|
||||
import React, {
|
||||
|
@ -31,6 +29,7 @@ import React, {
|
|||
useState,
|
||||
useReducer,
|
||||
Reducer,
|
||||
useCallback,
|
||||
} from 'react';
|
||||
import { useHistory } from 'react-router-dom';
|
||||
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
|
||||
|
@ -65,6 +64,7 @@ import {
|
|||
CatalogObject,
|
||||
Engines,
|
||||
ExtraJson,
|
||||
CustomTextType,
|
||||
} from '../types';
|
||||
import ExtraOptions from './ExtraOptions';
|
||||
import SqlAlchemyForm from './SqlAlchemyForm';
|
||||
|
@ -208,8 +208,8 @@ export type DBReducerActionType =
|
|||
| {
|
||||
type:
|
||||
| ActionType.Reset
|
||||
| ActionType.AddTableCatalogSheet
|
||||
| ActionType.RemoveSSHTunnelConfig;
|
||||
| ActionType.RemoveSSHTunnelConfig
|
||||
| ActionType.AddTableCatalogSheet;
|
||||
}
|
||||
| {
|
||||
type: ActionType.RemoveTableCatalogSheet;
|
||||
|
@ -595,7 +595,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
const SSHTunnelSwitchComponent =
|
||||
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
|
||||
|
||||
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean>(false);
|
||||
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean | undefined>(
|
||||
undefined,
|
||||
);
|
||||
|
||||
let dbConfigExtraExtension = extensionsRegistry.get(
|
||||
'databaseconnection.extraOption',
|
||||
|
@ -618,14 +620,6 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
const dbImages = getDatabaseImages();
|
||||
const connectionAlert = getConnectionAlert();
|
||||
const isEditMode = !!databaseId;
|
||||
const disableSSHTunnelingForEngine = (
|
||||
availableDbs?.databases?.find(
|
||||
(DB: DatabaseObject) =>
|
||||
DB.backend === db?.engine || DB.engine === db?.engine,
|
||||
) as DatabaseObject
|
||||
)?.engine_information?.disable_ssh_tunneling;
|
||||
const isSSHTunneling =
|
||||
isFeatureEnabled(FeatureFlag.SshTunneling) && !disableSSHTunnelingForEngine;
|
||||
const hasAlert =
|
||||
connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]);
|
||||
const useSqlAlchemyForm =
|
||||
|
@ -659,7 +653,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
extra: db?.extra,
|
||||
masked_encrypted_extra: db?.masked_encrypted_extra || '',
|
||||
server_cert: db?.server_cert || undefined,
|
||||
ssh_tunnel: db?.ssh_tunnel || undefined,
|
||||
ssh_tunnel:
|
||||
!isEmpty(db?.ssh_tunnel) && useSSHTunneling
|
||||
? {
|
||||
...db.ssh_tunnel,
|
||||
server_port: Number(db.ssh_tunnel!.server_port),
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
setTestInProgress(true);
|
||||
testDatabaseConnection(
|
||||
|
@ -687,10 +687,36 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
return false;
|
||||
};
|
||||
|
||||
const onChange = useCallback(
|
||||
(
|
||||
type: DBReducerActionType['type'],
|
||||
payload: CustomTextType | DBReducerPayloadType,
|
||||
) => {
|
||||
setDB({ type, payload } as DBReducerActionType);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const handleClearValidationErrors = useCallback(() => {
|
||||
setValidationErrors(null);
|
||||
}, [setValidationErrors]);
|
||||
|
||||
const handleParametersChange = useCallback(
|
||||
({ target }: { target: HTMLInputElement }) => {
|
||||
onChange(ActionType.ParametersChange, {
|
||||
type: target.type,
|
||||
name: target.name,
|
||||
checked: target.checked,
|
||||
value: target.value,
|
||||
});
|
||||
},
|
||||
[onChange],
|
||||
);
|
||||
|
||||
const onClose = () => {
|
||||
setDB({ type: ActionType.Reset });
|
||||
setHasConnectedDb(false);
|
||||
setValidationErrors(null); // reset validation errors on close
|
||||
handleClearValidationErrors(); // reset validation errors on close
|
||||
clearError();
|
||||
setEditNewDb(false);
|
||||
setFileList([]);
|
||||
|
@ -705,7 +731,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
setSSHTunnelPrivateKeys({});
|
||||
setSSHTunnelPrivateKeyPasswords({});
|
||||
setConfirmedOverwrite(false);
|
||||
setUseSSHTunneling(false);
|
||||
setUseSSHTunneling(undefined);
|
||||
onHide();
|
||||
};
|
||||
|
||||
|
@ -729,12 +755,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
setImportingErrorMessage(msg);
|
||||
});
|
||||
|
||||
const onChange = (type: any, payload: any) => {
|
||||
setDB({ type, payload } as DBReducerActionType);
|
||||
};
|
||||
|
||||
const onSave = async () => {
|
||||
let dbConfigExtraExtensionOnSaveError;
|
||||
|
||||
setLoading(true);
|
||||
|
||||
dbConfigExtraExtension
|
||||
?.onSave(extraExtensionComponentState, db)
|
||||
.then(({ error }: { error: any }) => {
|
||||
|
@ -743,6 +768,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
addDangerToast(error);
|
||||
}
|
||||
});
|
||||
|
||||
if (dbConfigExtraExtensionOnSaveError) {
|
||||
setLoading(false);
|
||||
return;
|
||||
|
@ -762,17 +788,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
});
|
||||
}
|
||||
|
||||
// only do validation for non ssh tunnel connections
|
||||
if (!dbToUpdate?.ssh_tunnel) {
|
||||
// make sure that button spinner animates
|
||||
setLoading(true);
|
||||
const errors = await getValidation(dbToUpdate, true);
|
||||
if ((validationErrors && !isEmpty(validationErrors)) || errors) {
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
// end spinner animation
|
||||
const errors = await getValidation(dbToUpdate, true);
|
||||
if (!isEmpty(validationErrors) || errors?.length) {
|
||||
addDangerToast(
|
||||
t('Connection failed, please check your connection settings.'),
|
||||
);
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const parameters_schema = isEditMode
|
||||
|
@ -829,7 +851,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
});
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
// strictly checking for false as an indication that the toggle got unchecked
|
||||
if (useSSHTunneling === false) {
|
||||
// remove ssh tunnel
|
||||
dbToUpdate.ssh_tunnel = null;
|
||||
}
|
||||
|
||||
if (db?.id) {
|
||||
const result = await updateResource(
|
||||
db.id as number,
|
||||
|
@ -1282,10 +1309,10 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
}, [sshPrivateKeyPasswordNeeded]);
|
||||
|
||||
useEffect(() => {
|
||||
if (db && isSSHTunneling) {
|
||||
setUseSSHTunneling(!isEmpty(db?.ssh_tunnel));
|
||||
if (db?.parameters?.ssh !== undefined) {
|
||||
setUseSSHTunneling(db.parameters.ssh);
|
||||
}
|
||||
}, [db, isSSHTunneling]);
|
||||
}, [db?.parameters?.ssh]);
|
||||
|
||||
const onDbImport = async (info: UploadChangeParam) => {
|
||||
setImportingErrorMessage('');
|
||||
|
@ -1550,17 +1577,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
const renderSSHTunnelForm = () => (
|
||||
<SSHTunnelForm
|
||||
db={db as DatabaseObject}
|
||||
onSSHTunnelParametersChange={({
|
||||
target,
|
||||
}: {
|
||||
target: HTMLInputElement | HTMLTextAreaElement;
|
||||
}) =>
|
||||
onSSHTunnelParametersChange={({ target }) => {
|
||||
onChange(ActionType.ParametersSSHTunnelChange, {
|
||||
type: target.type,
|
||||
name: target.name,
|
||||
value: target.value,
|
||||
})
|
||||
}
|
||||
});
|
||||
handleClearValidationErrors();
|
||||
}}
|
||||
setSSHTunnelLoginMethod={(method: AuthType) =>
|
||||
setDB({
|
||||
type: ActionType.SetSSHTunnelLoginMethod,
|
||||
|
@ -1623,14 +1647,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
payload: { indexToDelete: idx },
|
||||
});
|
||||
}}
|
||||
onParametersChange={({ target }: { target: HTMLInputElement }) =>
|
||||
onChange(ActionType.ParametersChange, {
|
||||
type: target.type,
|
||||
name: target.name,
|
||||
checked: target.checked,
|
||||
value: target.value,
|
||||
})
|
||||
}
|
||||
onParametersChange={handleParametersChange}
|
||||
onChange={({ target }: { target: HTMLInputElement }) =>
|
||||
onChange(ActionType.TextChange, {
|
||||
name: target.name,
|
||||
|
@ -1640,9 +1657,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
getValidation={() => getValidation(db)}
|
||||
validationErrors={validationErrors}
|
||||
getPlaceholder={getPlaceholder}
|
||||
clearValidationErrors={() => setValidationErrors(null)}
|
||||
clearValidationErrors={handleClearValidationErrors}
|
||||
/>
|
||||
{db?.parameters?.ssh && (
|
||||
{useSSHTunneling && (
|
||||
<SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer>
|
||||
)}
|
||||
</>
|
||||
|
@ -1792,13 +1809,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||
testInProgress={testInProgress}
|
||||
>
|
||||
<SSHTunnelSwitchComponent
|
||||
isEditMode={isEditMode}
|
||||
dbFetched={dbFetched}
|
||||
disableSSHTunnelingForEngine={disableSSHTunnelingForEngine}
|
||||
useSSHTunneling={useSSHTunneling}
|
||||
setUseSSHTunneling={setUseSSHTunneling}
|
||||
setDB={setDB}
|
||||
isSSHTunneling={isSSHTunneling}
|
||||
dbModel={dbModel}
|
||||
db={db as DatabaseObject}
|
||||
changeMethods={{
|
||||
onParametersChange: handleParametersChange,
|
||||
}}
|
||||
clearValidationErrors={handleClearValidationErrors}
|
||||
/>
|
||||
{useSSHTunneling && renderSSHTunnelForm()}
|
||||
</SqlAlchemyForm>
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
import { JsonObject } from '@superset-ui/core';
|
||||
import { InputProps } from 'antd/lib/input';
|
||||
import { ChangeEvent, EventHandler, FormEvent } from 'react';
|
||||
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
|
@ -108,7 +112,7 @@ export type DatabaseObject = {
|
|||
};
|
||||
|
||||
// SSH Tunnel information
|
||||
ssh_tunnel?: SSHTunnelObject;
|
||||
ssh_tunnel?: SSHTunnelObject | null;
|
||||
};
|
||||
|
||||
export type DatabaseForm = {
|
||||
|
@ -195,6 +199,10 @@ export type DatabaseForm = {
|
|||
};
|
||||
preferred: boolean;
|
||||
sqlalchemy_uri_placeholder: string;
|
||||
engine_information: {
|
||||
supports_file_upload: boolean;
|
||||
disable_ssh_tunneling: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
// the values should align with the database
|
||||
|
@ -231,3 +239,73 @@ export interface ExtraJson {
|
|||
};
|
||||
version?: string;
|
||||
}
|
||||
|
||||
export type CustomTextType = {
|
||||
value?: string | boolean | number;
|
||||
type?: string | null;
|
||||
name?: string;
|
||||
checked?: boolean;
|
||||
};
|
||||
|
||||
type CustomHTMLInputElement = Omit<Partial<CustomTextType>, 'value' | 'type'> &
|
||||
CustomTextType;
|
||||
|
||||
type CustomHTMLTextAreaElement = Omit<
|
||||
Partial<CustomTextType>,
|
||||
'value' | 'type'
|
||||
> &
|
||||
CustomTextType;
|
||||
|
||||
export type CustomParametersChangeType<T = CustomTextType> =
|
||||
| FormEvent<InputProps>
|
||||
| { target: T };
|
||||
|
||||
export type CustomEventHandlerType = EventHandler<
|
||||
ChangeEvent<CustomHTMLInputElement | CustomHTMLTextAreaElement>
|
||||
>;
|
||||
|
||||
export interface FieldPropTypes {
|
||||
required: boolean;
|
||||
hasTooltip?: boolean;
|
||||
tooltipText?: (value: any) => string;
|
||||
placeholder?: string;
|
||||
onParametersChange: (event: CustomParametersChangeType) => void;
|
||||
onParametersUploadFileChange: (value: any) => string;
|
||||
changeMethods: {
|
||||
onParametersChange: (event: CustomParametersChangeType) => void;
|
||||
} & {
|
||||
onChange: (value: any) => string;
|
||||
} & {
|
||||
onQueryChange: (value: any) => string;
|
||||
} & { onParametersUploadFileChange: (value: any) => string } & {
|
||||
onAddTableCatalog: () => void;
|
||||
onRemoveTableCatalog: (idx: number) => void;
|
||||
} & {
|
||||
onExtraInputChange: (value: any) => void;
|
||||
onSSHTunnelParametersChange: CustomEventHandlerType;
|
||||
};
|
||||
validationErrors: JsonObject | null;
|
||||
getValidation: () => void;
|
||||
clearValidationErrors: () => void;
|
||||
db?: DatabaseObject;
|
||||
dbModel?: DatabaseForm;
|
||||
field: string;
|
||||
isEditMode?: boolean;
|
||||
sslForced?: boolean;
|
||||
defaultDBName?: string;
|
||||
editNewDb?: boolean;
|
||||
}
|
||||
|
||||
type ChangeMethodsType = FieldPropTypes['changeMethods'];
|
||||
|
||||
// changeMethods compatibility with dynamic forms
|
||||
type SwitchPropsChangeMethodsType = {
|
||||
onParametersChange: ChangeMethodsType['onParametersChange'];
|
||||
};
|
||||
|
||||
export type SwitchProps = {
|
||||
dbModel: DatabaseForm;
|
||||
db: DatabaseObject;
|
||||
changeMethods: SwitchPropsChangeMethodsType;
|
||||
clearValidationErrors: () => void;
|
||||
};
|
||||
|
|
|
@ -35,7 +35,8 @@ import Chart, { Slice } from 'src/types/Chart';
|
|||
import copyTextToClipboard from 'src/utils/copy';
|
||||
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
|
||||
import SupersetText from 'src/utils/textUtils';
|
||||
import { FavoriteStatus, ImportResourceName, DatabaseObject } from './types';
|
||||
import { DatabaseObject } from 'src/features/databases/types';
|
||||
import { FavoriteStatus, ImportResourceName } from './types';
|
||||
|
||||
interface ListViewResourceState<D extends object = any> {
|
||||
loading: boolean;
|
||||
|
@ -691,7 +692,7 @@ export const getDatabaseDocumentationLinks = () =>
|
|||
SupersetText.DB_CONNECTION_DOC_LINKS;
|
||||
|
||||
export const testDatabaseConnection = (
|
||||
connection: DatabaseObject,
|
||||
connection: Partial<DatabaseObject>,
|
||||
handleErrorMsg: (errorMsg: string) => void,
|
||||
addSuccessToast: (arg0: string) => void,
|
||||
) => {
|
||||
|
@ -745,7 +746,7 @@ export function useDatabaseValidation() {
|
|||
const getValidation = useCallback(
|
||||
(database: Partial<DatabaseObject> | null, onCreate = false) => {
|
||||
if (database?.parameters?.ssh) {
|
||||
// when ssh tunnel is enabled we don't want to render any validation errors
|
||||
// TODO: /validate_parameters/ and related utils should support ssh tunnel
|
||||
setValidationErrors(null);
|
||||
return [];
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import Any, Optional
|
|||
|
||||
from flask import current_app
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_babel import gettext as _
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
|
@ -33,6 +34,7 @@ from superset.commands.database.exceptions import (
|
|||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
|
@ -57,7 +59,11 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
try:
|
||||
# Test connection before starting create transaction
|
||||
TestConnectionDatabaseCommand(self._properties).run()
|
||||
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
|
||||
except (
|
||||
SupersetErrorsException,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
|
@ -103,6 +109,7 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
SSHTunnelInvalidError,
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
) as ex:
|
||||
db.session.rollback()
|
||||
event_logger.log_with_context(
|
||||
|
@ -140,6 +147,7 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
# Check database_name uniqueness
|
||||
if not DatabaseDAO.validate_uniqueness(database_name):
|
||||
exceptions.append(DatabaseExistsValidationError())
|
||||
|
||||
if exceptions:
|
||||
exception = DatabaseInvalidError()
|
||||
exception.extend(exceptions)
|
||||
|
|
|
@ -23,11 +23,13 @@ from marshmallow import ValidationError
|
|||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
)
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
|
||||
|
@ -35,9 +37,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CreateSSHTunnelCommand(BaseCommand):
|
||||
_database: Database
|
||||
|
||||
def __init__(self, database: Database, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._properties["database"] = database
|
||||
self._database = database
|
||||
|
||||
def run(self) -> Model:
|
||||
try:
|
||||
|
@ -57,16 +62,22 @@ class CreateSSHTunnelCommand(BaseCommand):
|
|||
server_address: Optional[str] = self._properties.get("server_address")
|
||||
server_port: Optional[int] = self._properties.get("server_port")
|
||||
username: Optional[str] = self._properties.get("username")
|
||||
password: Optional[str] = self._properties.get("password")
|
||||
private_key: Optional[str] = self._properties.get("private_key")
|
||||
private_key_password: Optional[str] = self._properties.get(
|
||||
"private_key_password"
|
||||
)
|
||||
url = make_url_safe(self._database.sqlalchemy_uri)
|
||||
if not url.port:
|
||||
raise SSHTunnelDatabasePortError()
|
||||
if not server_address:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
|
||||
if not server_port:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
|
||||
if not username:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
|
||||
if not private_key and not password:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("password"))
|
||||
if private_key_password and private_key is None:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||
if exceptions:
|
||||
|
|
|
@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError):
|
|||
message = _("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
class SSHTunnelDatabasePortError(CommandInvalidError):
|
||||
message = _("A database port is required when connecting via SSH Tunnel.")
|
||||
|
||||
|
||||
class SSHTunnelUpdateFailedError(UpdateFailedError):
|
||||
message = _("SSH Tunnel could not be updated.")
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model
|
|||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelNotFoundError,
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
|
@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
|||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAOUpdateFailedError
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,20 +41,33 @@ class UpdateSSHTunnelCommand(BaseCommand):
|
|||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
def run(self) -> Optional[Model]:
|
||||
self.validate()
|
||||
try:
|
||||
if self._model is not None: # So we dont get incompatible types error
|
||||
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||
if self._model is None:
|
||||
return None
|
||||
|
||||
# unset password if private key is provided
|
||||
if self._properties.get("private_key"):
|
||||
self._properties["password"] = None
|
||||
|
||||
# unset private key and password if password is provided
|
||||
if self._properties.get("password"):
|
||||
self._properties["private_key"] = None
|
||||
self._properties["private_key_password"] = None
|
||||
|
||||
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||
return tunnel
|
||||
except DAOUpdateFailedError as ex:
|
||||
raise SSHTunnelUpdateFailedError() from ex
|
||||
return tunnel
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise SSHTunnelNotFoundError()
|
||||
|
||||
url = make_url_safe(self._model.database.sqlalchemy_uri)
|
||||
private_key: Optional[str] = self._properties.get("private_key")
|
||||
private_key_password: Optional[str] = self._properties.get(
|
||||
"private_key_password"
|
||||
|
@ -61,3 +76,5 @@ class UpdateSSHTunnelCommand(BaseCommand):
|
|||
raise SSHTunnelInvalidError(
|
||||
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
|
||||
)
|
||||
if not url.port:
|
||||
raise SSHTunnelDatabasePortError()
|
||||
|
|
|
@ -32,7 +32,10 @@ from superset.commands.database.exceptions import (
|
|||
DatabaseTestConnectionDriverError,
|
||||
DatabaseTestConnectionUnexpectedError,
|
||||
)
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
@ -61,20 +64,22 @@ def get_log_connection_action(
|
|||
|
||||
|
||||
class TestConnectionDatabaseCommand(BaseCommand):
|
||||
_model: Optional[Database] = None
|
||||
_context: dict[str, Any]
|
||||
_uri: str
|
||||
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
|
||||
self.validate()
|
||||
ex_str = ""
|
||||
if (database_name := self._properties.get("database_name")) is not None:
|
||||
self._model = DatabaseDAO.get_database_by_name(database_name)
|
||||
|
||||
uri = self._properties.get("sqlalchemy_uri", "")
|
||||
if self._model and uri == self._model.safe_sqlalchemy_uri():
|
||||
uri = self._model.sqlalchemy_uri_decrypted
|
||||
ssh_tunnel = self._properties.get("ssh_tunnel")
|
||||
|
||||
# context for error messages
|
||||
url = make_url_safe(uri)
|
||||
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
|
@ -83,6 +88,14 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
"database": url.database,
|
||||
}
|
||||
|
||||
self._context = context
|
||||
self._uri = uri
|
||||
|
||||
def run(self) -> None: # pylint: disable=too-many-statements
|
||||
self.validate()
|
||||
ex_str = ""
|
||||
ssh_tunnel = self._properties.get("ssh_tunnel")
|
||||
|
||||
serialized_encrypted_extra = self._properties.get(
|
||||
"masked_encrypted_extra",
|
||||
"{}",
|
||||
|
@ -103,15 +116,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
encrypted_extra=serialized_encrypted_extra,
|
||||
)
|
||||
|
||||
database.set_sqlalchemy_uri(uri)
|
||||
database.set_sqlalchemy_uri(self._uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
# Generate tunnel if present in the properties
|
||||
if ssh_tunnel:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
# If there's an existing tunnel for that DB we need to use the stored
|
||||
# password, private_key and private_key_password instead
|
||||
# unmask password while allowing for updated values
|
||||
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
|
||||
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
|
||||
ssh_tunnel = unmask_password_info(
|
||||
|
@ -186,7 +196,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# check for custom errors (wrong username, wrong password, etc)
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
||||
raise SupersetErrorsException(errors) from ex
|
||||
except SupersetSecurityException as ex:
|
||||
event_logger.log_with_context(
|
||||
|
@ -221,9 +231,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
||||
raise DatabaseTestConnectionUnexpectedError(errors) from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
if (database_name := self._properties.get("database_name")) is not None:
|
||||
self._model = DatabaseDAO.get_database_by_name(database_name)
|
||||
if self._properties.get("ssh_tunnel"):
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
if not self._context.get("port"):
|
||||
raise SSHTunnelDatabasePortError()
|
||||
|
|
|
@ -18,6 +18,7 @@ import logging
|
|||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_babel import gettext as _
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
|
@ -30,8 +31,11 @@ from superset.commands.database.exceptions import (
|
|||
DatabaseUpdateFailedError,
|
||||
)
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelUpdateFailedError,
|
||||
|
@ -47,15 +51,21 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UpdateDatabaseCommand(BaseCommand):
|
||||
_model: Optional[Database]
|
||||
|
||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
|
||||
self.validate()
|
||||
|
||||
old_database_name = self._model.database_name
|
||||
|
||||
# unmask ``encrypted_extra``
|
||||
|
@ -70,36 +80,59 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
database = DatabaseDAO.update(self._model, self._properties, commit=False)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
|
||||
if "ssh_tunnel" in self._properties:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
db.session.rollback()
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
if existing_ssh_tunnel_model is None:
|
||||
# We couldn't found an existing tunnel so we need to create one
|
||||
|
||||
if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
|
||||
# We need to remove the existing tunnel
|
||||
try:
|
||||
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
else:
|
||||
# We found an existing tunnel so we need to update it
|
||||
try:
|
||||
UpdateSSHTunnelCommand(
|
||||
existing_ssh_tunnel_model.id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
DeleteSSHTunnelCommand(ssh_tunnel.id).run()
|
||||
ssh_tunnel = None
|
||||
except SSHTunnelDeleteFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
if ssh_tunnel is None:
|
||||
# We couldn't found an existing tunnel so we need to create one
|
||||
try:
|
||||
ssh_tunnel = CreateSSHTunnelCommand(
|
||||
database, ssh_tunnel_properties
|
||||
).run()
|
||||
except (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelDatabasePortError,
|
||||
) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
else:
|
||||
# We found an existing tunnel so we need to update it
|
||||
try:
|
||||
ssh_tunnel_id = ssh_tunnel.id
|
||||
ssh_tunnel = UpdateSSHTunnelCommand(
|
||||
ssh_tunnel_id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelUpdateFailedError,
|
||||
SSHTunnelDatabasePortError,
|
||||
) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
|
||||
# adding a new database we always want to force refresh schema list
|
||||
# TODO Improve this simplistic implementation for catching DB conn fails
|
||||
try:
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
|
@ -167,10 +200,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
|
||||
def validate(self) -> None:
|
||||
exceptions: list[ValidationError] = []
|
||||
# Validate/populate model exists
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
database_name: Optional[str] = self._properties.get("database_name")
|
||||
if database_name:
|
||||
# Check database_name uniqueness
|
||||
|
|
|
@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand
|
|||
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
|
||||
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
|
@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
|
@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>", methods=("DELETE",))
|
||||
|
@ -918,7 +919,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
try:
|
||||
TestConnectionDatabaseCommand(item).run()
|
||||
return self.response(200, message="OK")
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>/related_objects/", methods=("GET",))
|
||||
|
|
|
@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError
|
|||
from sqlalchemy.sql import func
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_create_database_with_missing_port_raises_error(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test that missing port raises SSHTunnelDatabaseError
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
|
||||
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": modified_sqlalchemy_uri,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": modified_sqlalchemy_uri,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(
|
||||
response.get("message"),
|
||||
"A database port is required when connecting via SSH Tunnel.",
|
||||
)
|
||||
|
||||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
|
@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_database_with_missing_port_raises_error(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test that missing port raises SSHTunnelDatabaseError
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
mock_update_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
|
||||
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": modified_sqlalchemy_uri,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response_create = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
|
||||
uri = "api/v1/database/{}".format(response_create.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(
|
||||
response.get("message"),
|
||||
"A database port is required when connecting via SSH Tunnel.",
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response_create.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_delete_ssh_tunnel(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_delete_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test deleting a SSH tunnel via Database update
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
mock_update_is_feature_enabled.return_value = True
|
||||
mock_delete_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
|
||||
database_data_with_ssh_tunnel_null = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": None,
|
||||
}
|
||||
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
assert model_ssh_tunnel is None
|
||||
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
|
|
|
@ -19,7 +19,10 @@
|
|||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command() -> None:
|
||||
|
@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None:
|
|||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
|
||||
)
|
||||
|
||||
properties = {
|
||||
"database_id": database.id,
|
||||
|
@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
|
|||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
|
||||
)
|
||||
|
||||
# If we are trying to create a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
|
@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
|
|||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command_no_port() -> None:
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="postgresql://u:p@localhost/db",
|
||||
)
|
||||
|
||||
properties = {
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(database, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == (
|
||||
"A database port is required when connecting via SSH Tunnel."
|
||||
)
|
||||
|
|
|
@ -20,11 +20,14 @@ from collections.abc import Iterator
|
|||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
def session_with_data(request, session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
|||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
|
@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
|
|||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
|
||||
)
|
||||
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
update_payload = {"server_address": "Test update"}
|
||||
command = UpdateSSHTunnelCommand(1, update_payload)
|
||||
|
||||
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == (
|
||||
"A database port is required when connecting via SSH Tunnel."
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue