mirror of https://github.com/apache/superset.git
fix: SSH Tunnel configuration settings (#27186)
This commit is contained in:
parent
fde93dcf08
commit
89e89de341
|
@ -44,15 +44,15 @@ interface MenuObjectChildProps {
|
||||||
disable?: boolean;
|
disable?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SwitchProps {
|
// loose typing to avoid any circular dependencies
|
||||||
isEditMode: boolean;
|
// refer to SSHTunnelSwitch component for strict typing
|
||||||
dbFetched: any;
|
type SwitchProps = {
|
||||||
disableSSHTunnelingForEngine?: boolean;
|
db: object;
|
||||||
useSSHTunneling: boolean;
|
changeMethods: {
|
||||||
setUseSSHTunneling: React.Dispatch<React.SetStateAction<boolean>>;
|
onParametersChange: (event: any) => void;
|
||||||
setDB: React.Dispatch<any>;
|
};
|
||||||
isSSHTunneling: boolean;
|
clearValidationErrors: () => void;
|
||||||
}
|
};
|
||||||
|
|
||||||
type ConfigDetailsProps = {
|
type ConfigDetailsProps = {
|
||||||
embeddedId: string;
|
embeddedId: string;
|
||||||
|
|
|
@ -541,8 +541,8 @@ test('defaults to day when CRON is not selected', async () => {
|
||||||
useRedux: true,
|
useRedux: true,
|
||||||
});
|
});
|
||||||
userEvent.click(screen.getByTestId('schedule-panel'));
|
userEvent.click(screen.getByTestId('schedule-panel'));
|
||||||
const days = screen.getAllByTitle(/day/i, { exact: true });
|
const day = screen.getByText('day');
|
||||||
expect(days.length).toBe(2);
|
expect(day).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Notification Method Section
|
// Notification Method Section
|
||||||
|
|
|
@ -17,12 +17,11 @@
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { isEmpty } from 'lodash';
|
|
||||||
import { SupersetTheme, t } from '@superset-ui/core';
|
import { SupersetTheme, t } from '@superset-ui/core';
|
||||||
import { AntdSwitch } from 'src/components';
|
import { AntdSwitch } from 'src/components';
|
||||||
import InfoTooltip from 'src/components/InfoTooltip';
|
import InfoTooltip from 'src/components/InfoTooltip';
|
||||||
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
||||||
import { FieldPropTypes } from '.';
|
import { FieldPropTypes } from '../../types';
|
||||||
import { toggleStyle, infoTooltip } from '../styles';
|
import { toggleStyle, infoTooltip } from '../styles';
|
||||||
|
|
||||||
export const hostField = ({
|
export const hostField = ({
|
||||||
|
@ -252,35 +251,3 @@ export const forceSSLField = ({
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
||||||
export const SSHTunnelSwitch = ({
|
|
||||||
isEditMode,
|
|
||||||
changeMethods,
|
|
||||||
clearValidationErrors,
|
|
||||||
db,
|
|
||||||
}: FieldPropTypes) => (
|
|
||||||
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
|
|
||||||
<AntdSwitch
|
|
||||||
disabled={isEditMode && !isEmpty(db?.ssh_tunnel)}
|
|
||||||
checked={db?.parameters?.ssh}
|
|
||||||
onChange={changed => {
|
|
||||||
changeMethods.onParametersChange({
|
|
||||||
target: {
|
|
||||||
type: 'toggle',
|
|
||||||
name: 'ssh',
|
|
||||||
checked: true,
|
|
||||||
value: changed,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
clearValidationErrors();
|
|
||||||
}}
|
|
||||||
data-test="ssh-tunnel-switch"
|
|
||||||
/>
|
|
||||||
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
|
|
||||||
<InfoTooltip
|
|
||||||
tooltip={t('SSH Tunnel configuration parameters')}
|
|
||||||
placement="right"
|
|
||||||
viewBox="0 -5 24 24"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ import { AntdButton, AntdSelect } from 'src/components';
|
||||||
import InfoTooltip from 'src/components/InfoTooltip';
|
import InfoTooltip from 'src/components/InfoTooltip';
|
||||||
import FormLabel from 'src/components/Form/FormLabel';
|
import FormLabel from 'src/components/Form/FormLabel';
|
||||||
import Icons from 'src/components/Icons';
|
import Icons from 'src/components/Icons';
|
||||||
import { FieldPropTypes } from '.';
|
import { FieldPropTypes } from '../../types';
|
||||||
import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles';
|
import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles';
|
||||||
|
|
||||||
enum CredentialInfoOptions {
|
enum CredentialInfoOptions {
|
||||||
|
|
|
@ -21,9 +21,8 @@ import { css, SupersetTheme, t } from '@superset-ui/core';
|
||||||
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
||||||
import FormLabel from 'src/components/Form/FormLabel';
|
import FormLabel from 'src/components/Form/FormLabel';
|
||||||
import Icons from 'src/components/Icons';
|
import Icons from 'src/components/Icons';
|
||||||
import { FieldPropTypes } from '.';
|
|
||||||
import { StyledFooterButton, StyledCatalogTable } from '../styles';
|
import { StyledFooterButton, StyledCatalogTable } from '../styles';
|
||||||
import { CatalogObject } from '../../types';
|
import { CatalogObject, FieldPropTypes } from '../../types';
|
||||||
|
|
||||||
export const TableCatalog = ({
|
export const TableCatalog = ({
|
||||||
required,
|
required,
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { t } from '@superset-ui/core';
|
import { t } from '@superset-ui/core';
|
||||||
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
|
||||||
import { FieldPropTypes } from '.';
|
import { FieldPropTypes } from '../../types';
|
||||||
|
|
||||||
const FIELD_TEXT_MAP = {
|
const FIELD_TEXT_MAP = {
|
||||||
account: {
|
account: {
|
||||||
|
|
|
@ -17,7 +17,11 @@
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
import React, { FormEvent } from 'react';
|
import React, { FormEvent } from 'react';
|
||||||
import { SupersetTheme, JsonObject } from '@superset-ui/core';
|
import {
|
||||||
|
SupersetTheme,
|
||||||
|
JsonObject,
|
||||||
|
getExtensionsRegistry,
|
||||||
|
} from '@superset-ui/core';
|
||||||
import { InputProps } from 'antd/lib/input';
|
import { InputProps } from 'antd/lib/input';
|
||||||
import { Form } from 'src/components/Form';
|
import { Form } from 'src/components/Form';
|
||||||
import {
|
import {
|
||||||
|
@ -31,13 +35,13 @@ import {
|
||||||
portField,
|
portField,
|
||||||
queryField,
|
queryField,
|
||||||
usernameField,
|
usernameField,
|
||||||
SSHTunnelSwitch,
|
|
||||||
} from './CommonParameters';
|
} from './CommonParameters';
|
||||||
import { validatedInputField } from './ValidatedInputField';
|
import { validatedInputField } from './ValidatedInputField';
|
||||||
import { EncryptedField } from './EncryptedField';
|
import { EncryptedField } from './EncryptedField';
|
||||||
import { TableCatalog } from './TableCatalog';
|
import { TableCatalog } from './TableCatalog';
|
||||||
import { formScrollableStyles, validatedFormStyles } from '../styles';
|
import { formScrollableStyles, validatedFormStyles } from '../styles';
|
||||||
import { DatabaseForm, DatabaseObject } from '../../types';
|
import { DatabaseForm, DatabaseObject } from '../../types';
|
||||||
|
import SSHTunnelSwitch from '../SSHTunnelSwitch';
|
||||||
|
|
||||||
export const FormFieldOrder = [
|
export const FormFieldOrder = [
|
||||||
'host',
|
'host',
|
||||||
|
@ -59,34 +63,10 @@ export const FormFieldOrder = [
|
||||||
'ssh',
|
'ssh',
|
||||||
];
|
];
|
||||||
|
|
||||||
export interface FieldPropTypes {
|
const extensionsRegistry = getExtensionsRegistry();
|
||||||
required: boolean;
|
|
||||||
hasTooltip?: boolean;
|
const SSHTunnelSwitchComponent =
|
||||||
tooltipText?: (value: any) => string;
|
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
|
||||||
placeholder?: string;
|
|
||||||
onParametersChange: (value: any) => string;
|
|
||||||
onParametersUploadFileChange: (value: any) => string;
|
|
||||||
changeMethods: { onParametersChange: (value: any) => string } & {
|
|
||||||
onChange: (value: any) => string;
|
|
||||||
} & {
|
|
||||||
onQueryChange: (value: any) => string;
|
|
||||||
} & { onParametersUploadFileChange: (value: any) => string } & {
|
|
||||||
onAddTableCatalog: () => void;
|
|
||||||
onRemoveTableCatalog: (idx: number) => void;
|
|
||||||
} & {
|
|
||||||
onExtraInputChange: (value: any) => void;
|
|
||||||
onSSHTunnelParametersChange: (value: any) => string;
|
|
||||||
};
|
|
||||||
validationErrors: JsonObject | null;
|
|
||||||
getValidation: () => void;
|
|
||||||
clearValidationErrors: () => void;
|
|
||||||
db?: DatabaseObject;
|
|
||||||
field: string;
|
|
||||||
isEditMode?: boolean;
|
|
||||||
sslForced?: boolean;
|
|
||||||
defaultDBName?: string;
|
|
||||||
editNewDb?: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
const FORM_FIELD_MAP = {
|
const FORM_FIELD_MAP = {
|
||||||
host: hostField,
|
host: hostField,
|
||||||
|
@ -105,7 +85,7 @@ const FORM_FIELD_MAP = {
|
||||||
warehouse: validatedInputField,
|
warehouse: validatedInputField,
|
||||||
role: validatedInputField,
|
role: validatedInputField,
|
||||||
account: validatedInputField,
|
account: validatedInputField,
|
||||||
ssh: SSHTunnelSwitch,
|
ssh: SSHTunnelSwitchComponent,
|
||||||
};
|
};
|
||||||
|
|
||||||
interface DatabaseConnectionFormProps {
|
interface DatabaseConnectionFormProps {
|
||||||
|
@ -138,7 +118,7 @@ interface DatabaseConnectionFormProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
const DatabaseConnectionForm = ({
|
const DatabaseConnectionForm = ({
|
||||||
dbModel: { parameters },
|
dbModel,
|
||||||
db,
|
db,
|
||||||
editNewDb,
|
editNewDb,
|
||||||
getPlaceholder,
|
getPlaceholder,
|
||||||
|
@ -154,47 +134,51 @@ const DatabaseConnectionForm = ({
|
||||||
sslForced,
|
sslForced,
|
||||||
validationErrors,
|
validationErrors,
|
||||||
clearValidationErrors,
|
clearValidationErrors,
|
||||||
}: DatabaseConnectionFormProps) => (
|
}: DatabaseConnectionFormProps) => {
|
||||||
<Form>
|
const parameters = dbModel?.parameters;
|
||||||
<div
|
|
||||||
// @ts-ignore
|
return (
|
||||||
css={(theme: SupersetTheme) => [
|
<Form>
|
||||||
formScrollableStyles,
|
<div
|
||||||
validatedFormStyles(theme),
|
// @ts-ignore
|
||||||
]}
|
css={(theme: SupersetTheme) => [
|
||||||
>
|
formScrollableStyles,
|
||||||
{parameters &&
|
validatedFormStyles(theme),
|
||||||
FormFieldOrder.filter(
|
]}
|
||||||
(key: string) =>
|
>
|
||||||
Object.keys(parameters.properties).includes(key) ||
|
{parameters &&
|
||||||
key === 'database_name',
|
FormFieldOrder.filter(
|
||||||
).map(field =>
|
(key: string) =>
|
||||||
FORM_FIELD_MAP[field]({
|
Object.keys(parameters.properties).includes(key) ||
|
||||||
required: parameters.required?.includes(field),
|
key === 'database_name',
|
||||||
changeMethods: {
|
).map(field =>
|
||||||
onParametersChange,
|
FORM_FIELD_MAP[field]({
|
||||||
onChange,
|
required: parameters.required?.includes(field),
|
||||||
onQueryChange,
|
changeMethods: {
|
||||||
onParametersUploadFileChange,
|
onParametersChange,
|
||||||
onAddTableCatalog,
|
onChange,
|
||||||
onRemoveTableCatalog,
|
onQueryChange,
|
||||||
onExtraInputChange,
|
onParametersUploadFileChange,
|
||||||
},
|
onAddTableCatalog,
|
||||||
validationErrors,
|
onRemoveTableCatalog,
|
||||||
getValidation,
|
onExtraInputChange,
|
||||||
clearValidationErrors,
|
},
|
||||||
db,
|
validationErrors,
|
||||||
key: field,
|
getValidation,
|
||||||
field,
|
clearValidationErrors,
|
||||||
isEditMode,
|
db,
|
||||||
sslForced,
|
key: field,
|
||||||
editNewDb,
|
field,
|
||||||
placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
|
isEditMode,
|
||||||
}),
|
sslForced,
|
||||||
)}
|
editNewDb,
|
||||||
</div>
|
placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
|
||||||
</Form>
|
}),
|
||||||
);
|
)}
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
);
|
||||||
|
};
|
||||||
export const FormFieldMap = FORM_FIELD_MAP;
|
export const FormFieldMap = FORM_FIELD_MAP;
|
||||||
|
|
||||||
export default DatabaseConnectionForm;
|
export default DatabaseConnectionForm;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
import React, { EventHandler, ChangeEvent, useState } from 'react';
|
import React, { useState } from 'react';
|
||||||
import { t, styled } from '@superset-ui/core';
|
import { t, styled } from '@superset-ui/core';
|
||||||
import { AntdForm, Col, Row } from 'src/components';
|
import { AntdForm, Col, Row } from 'src/components';
|
||||||
import { Form, FormLabel } from 'src/components/Form';
|
import { Form, FormLabel } from 'src/components/Form';
|
||||||
|
@ -24,7 +24,7 @@ import { Radio } from 'src/components/Radio';
|
||||||
import { Input, TextArea } from 'src/components/Input';
|
import { Input, TextArea } from 'src/components/Input';
|
||||||
import { Input as AntdInput, Tooltip } from 'antd';
|
import { Input as AntdInput, Tooltip } from 'antd';
|
||||||
import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
||||||
import { DatabaseObject } from '../types';
|
import { DatabaseObject, FieldPropTypes } from '../types';
|
||||||
import { AuthType } from '.';
|
import { AuthType } from '.';
|
||||||
|
|
||||||
const StyledDiv = styled.div`
|
const StyledDiv = styled.div`
|
||||||
|
@ -54,9 +54,7 @@ const SSHTunnelForm = ({
|
||||||
setSSHTunnelLoginMethod,
|
setSSHTunnelLoginMethod,
|
||||||
}: {
|
}: {
|
||||||
db: DatabaseObject | null;
|
db: DatabaseObject | null;
|
||||||
onSSHTunnelParametersChange: EventHandler<
|
onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange'];
|
||||||
ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
|
|
||||||
>;
|
|
||||||
setSSHTunnelLoginMethod: (method: AuthType) => void;
|
setSSHTunnelLoginMethod: (method: AuthType) => void;
|
||||||
}) => {
|
}) => {
|
||||||
const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password);
|
const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password);
|
||||||
|
@ -86,9 +84,9 @@ const SSHTunnelForm = ({
|
||||||
</FormLabel>
|
</FormLabel>
|
||||||
<Input
|
<Input
|
||||||
name="server_port"
|
name="server_port"
|
||||||
type="text"
|
|
||||||
placeholder={t('22')}
|
placeholder={t('22')}
|
||||||
value={db?.ssh_tunnel?.server_port || ''}
|
type="number"
|
||||||
|
value={db?.ssh_tunnel?.server_port}
|
||||||
onChange={onSSHTunnelParametersChange}
|
onChange={onSSHTunnelParametersChange}
|
||||||
data-test="ssh-tunnel-server_port-input"
|
data-test="ssh-tunnel-server_port-input"
|
||||||
/>
|
/>
|
||||||
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
import React from 'react';
|
||||||
|
import { render, screen } from 'spec/helpers/testing-library';
|
||||||
|
import userEvent from '@testing-library/user-event';
|
||||||
|
import SSHTunnelSwitch from './SSHTunnelSwitch';
|
||||||
|
import { DatabaseForm, DatabaseObject } from '../types';
|
||||||
|
|
||||||
|
jest.mock('@superset-ui/core', () => ({
|
||||||
|
...jest.requireActual('@superset-ui/core'),
|
||||||
|
isFeatureEnabled: jest.fn().mockReturnValue(true),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('src/components', () => ({
|
||||||
|
AntdSwitch: ({
|
||||||
|
checked,
|
||||||
|
onChange,
|
||||||
|
}: {
|
||||||
|
checked: boolean;
|
||||||
|
onChange: (checked: boolean) => void;
|
||||||
|
}) => (
|
||||||
|
<button
|
||||||
|
onClick={() => onChange(!checked)}
|
||||||
|
aria-checked={checked}
|
||||||
|
role="switch"
|
||||||
|
type="button"
|
||||||
|
>
|
||||||
|
{checked ? 'ON' : 'OFF'}
|
||||||
|
</button>
|
||||||
|
),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockChangeMethods = {
|
||||||
|
onParametersChange: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockDbModel = {
|
||||||
|
engine: 'mysql',
|
||||||
|
engine_information: {
|
||||||
|
disable_ssh_tunneling: false,
|
||||||
|
},
|
||||||
|
} as DatabaseForm;
|
||||||
|
|
||||||
|
const defaultDb = {
|
||||||
|
parameters: { ssh: false },
|
||||||
|
ssh_tunnel: {},
|
||||||
|
engine: 'mysql',
|
||||||
|
} as DatabaseObject;
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Renders SSH Tunnel switch enabled by default and toggles its state', () => {
|
||||||
|
render(
|
||||||
|
<SSHTunnelSwitch
|
||||||
|
changeMethods={mockChangeMethods}
|
||||||
|
clearValidationErrors={jest.fn}
|
||||||
|
db={defaultDb}
|
||||||
|
dbModel={mockDbModel}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
const switchButton = screen.getByRole('switch');
|
||||||
|
expect(switchButton).toHaveTextContent('OFF');
|
||||||
|
userEvent.click(switchButton);
|
||||||
|
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
|
||||||
|
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
|
||||||
|
});
|
||||||
|
expect(switchButton).toHaveTextContent('ON');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Does not render if SSH Tunnel is disabled', () => {
|
||||||
|
render(
|
||||||
|
<SSHTunnelSwitch
|
||||||
|
changeMethods={mockChangeMethods}
|
||||||
|
clearValidationErrors={jest.fn}
|
||||||
|
db={defaultDb}
|
||||||
|
dbModel={{
|
||||||
|
...mockDbModel,
|
||||||
|
engine_information: {
|
||||||
|
disable_ssh_tunneling: true,
|
||||||
|
supports_file_upload: false,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(screen.queryByRole('switch')).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Checks the switch based on db.parameters.ssh', () => {
|
||||||
|
const dbWithSSHTunnelEnabled = {
|
||||||
|
...defaultDb,
|
||||||
|
parameters: { ssh: true },
|
||||||
|
} as DatabaseObject;
|
||||||
|
render(
|
||||||
|
<SSHTunnelSwitch
|
||||||
|
changeMethods={mockChangeMethods}
|
||||||
|
clearValidationErrors={jest.fn}
|
||||||
|
db={dbWithSSHTunnelEnabled}
|
||||||
|
dbModel={mockDbModel}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(screen.getByRole('switch')).toHaveTextContent('ON');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Calls onParametersChange with true if SSH Tunnel info exists', () => {
|
||||||
|
const dbWithSSHTunnelInfo = {
|
||||||
|
...defaultDb,
|
||||||
|
parameters: { ssh: undefined },
|
||||||
|
ssh_tunnel: { host: 'example.com' },
|
||||||
|
} as DatabaseObject;
|
||||||
|
render(
|
||||||
|
<SSHTunnelSwitch
|
||||||
|
changeMethods={mockChangeMethods}
|
||||||
|
clearValidationErrors={jest.fn}
|
||||||
|
db={dbWithSSHTunnelInfo}
|
||||||
|
dbModel={mockDbModel}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
|
||||||
|
target: { type: 'toggle', name: 'ssh', checked: true, value: true },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Displays tooltip text on hover over the InfoTooltip', async () => {
|
||||||
|
const tooltipText = 'SSH Tunnel configuration parameters';
|
||||||
|
render(
|
||||||
|
<SSHTunnelSwitch
|
||||||
|
changeMethods={mockChangeMethods}
|
||||||
|
clearValidationErrors={jest.fn}
|
||||||
|
db={defaultDb}
|
||||||
|
dbModel={mockDbModel}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
const infoTooltipTrigger = screen.getByRole('img', {
|
||||||
|
name: 'info-solid_small',
|
||||||
|
});
|
||||||
|
expect(infoTooltipTrigger).toBeInTheDocument();
|
||||||
|
|
||||||
|
userEvent.hover(infoTooltipTrigger);
|
||||||
|
|
||||||
|
const tooltip = await screen.findByText(tooltipText);
|
||||||
|
|
||||||
|
expect(tooltip).toBeInTheDocument();
|
||||||
|
});
|
|
@ -16,35 +16,73 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
import React from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { t, SupersetTheme, SwitchProps } from '@superset-ui/core';
|
import {
|
||||||
|
t,
|
||||||
|
SupersetTheme,
|
||||||
|
isFeatureEnabled,
|
||||||
|
FeatureFlag,
|
||||||
|
} from '@superset-ui/core';
|
||||||
import { AntdSwitch } from 'src/components';
|
import { AntdSwitch } from 'src/components';
|
||||||
import InfoTooltip from 'src/components/InfoTooltip';
|
import InfoTooltip from 'src/components/InfoTooltip';
|
||||||
import { isEmpty } from 'lodash';
|
import { isEmpty } from 'lodash';
|
||||||
import { ActionType } from '.';
|
|
||||||
import { infoTooltip, toggleStyle } from './styles';
|
import { infoTooltip, toggleStyle } from './styles';
|
||||||
|
import { SwitchProps } from '../types';
|
||||||
|
|
||||||
const SSHTunnelSwitch = ({
|
const SSHTunnelSwitch = ({
|
||||||
isEditMode,
|
clearValidationErrors,
|
||||||
dbFetched,
|
changeMethods,
|
||||||
useSSHTunneling,
|
db,
|
||||||
setUseSSHTunneling,
|
dbModel,
|
||||||
setDB,
|
}: SwitchProps) => {
|
||||||
isSSHTunneling,
|
const [isChecked, setChecked] = useState(false);
|
||||||
}: SwitchProps) =>
|
const sshTunnelEnabled = isFeatureEnabled(FeatureFlag.SshTunneling);
|
||||||
isSSHTunneling ? (
|
const disableSSHTunnelingForEngine =
|
||||||
|
dbModel?.engine_information?.disable_ssh_tunneling || false;
|
||||||
|
const isSSHTunnelEnabled = sshTunnelEnabled && !disableSSHTunnelingForEngine;
|
||||||
|
|
||||||
|
const handleOnChange = (changed: boolean) => {
|
||||||
|
setChecked(changed);
|
||||||
|
changeMethods.onParametersChange({
|
||||||
|
target: {
|
||||||
|
type: 'toggle',
|
||||||
|
name: 'ssh',
|
||||||
|
checked: true,
|
||||||
|
value: changed,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
clearValidationErrors();
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isSSHTunnelEnabled && db?.parameters?.ssh !== undefined) {
|
||||||
|
setChecked(db.parameters.ssh);
|
||||||
|
}
|
||||||
|
}, [db?.parameters?.ssh, isSSHTunnelEnabled]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (
|
||||||
|
isSSHTunnelEnabled &&
|
||||||
|
db?.parameters?.ssh === undefined &&
|
||||||
|
!isEmpty(db?.ssh_tunnel)
|
||||||
|
) {
|
||||||
|
// reflecting the state of the ssh tunnel on first load
|
||||||
|
changeMethods.onParametersChange({
|
||||||
|
target: {
|
||||||
|
type: 'toggle',
|
||||||
|
name: 'ssh',
|
||||||
|
checked: true,
|
||||||
|
value: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [changeMethods, db?.parameters?.ssh, db?.ssh_tunnel, isSSHTunnelEnabled]);
|
||||||
|
|
||||||
|
return isSSHTunnelEnabled ? (
|
||||||
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
|
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
|
||||||
<AntdSwitch
|
<AntdSwitch
|
||||||
disabled={isEditMode && !isEmpty(dbFetched?.ssh_tunnel)}
|
checked={isChecked}
|
||||||
checked={useSSHTunneling}
|
onChange={handleOnChange}
|
||||||
onChange={changed => {
|
|
||||||
setUseSSHTunneling(changed);
|
|
||||||
if (!changed) {
|
|
||||||
setDB({
|
|
||||||
type: ActionType.RemoveSSHTunnelConfig,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
data-test="ssh-tunnel-switch"
|
data-test="ssh-tunnel-switch"
|
||||||
/>
|
/>
|
||||||
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
|
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
|
||||||
|
@ -55,4 +93,6 @@ const SSHTunnelSwitch = ({
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
) : null;
|
) : null;
|
||||||
|
};
|
||||||
|
|
||||||
export default SSHTunnelSwitch;
|
export default SSHTunnelSwitch;
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// TODO: These tests should be made atomic in separate files
|
||||||
|
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
import fetchMock from 'fetch-mock';
|
import fetchMock from 'fetch-mock';
|
||||||
import userEvent from '@testing-library/user-event';
|
import userEvent from '@testing-library/user-event';
|
||||||
|
@ -1227,9 +1230,9 @@ describe('DatabaseModal', () => {
|
||||||
const SSHTunnelServerPortInput = screen.getByTestId(
|
const SSHTunnelServerPortInput = screen.getByTestId(
|
||||||
'ssh-tunnel-server_port-input',
|
'ssh-tunnel-server_port-input',
|
||||||
);
|
);
|
||||||
expect(SSHTunnelServerPortInput).toHaveValue('');
|
expect(SSHTunnelServerPortInput).toHaveValue(null);
|
||||||
userEvent.type(SSHTunnelServerPortInput, '22');
|
userEvent.type(SSHTunnelServerPortInput, '22');
|
||||||
expect(SSHTunnelServerPortInput).toHaveValue('22');
|
expect(SSHTunnelServerPortInput).toHaveValue(22);
|
||||||
const SSHTunnelUsernameInput = screen.getByTestId(
|
const SSHTunnelUsernameInput = screen.getByTestId(
|
||||||
'ssh-tunnel-username-input',
|
'ssh-tunnel-username-input',
|
||||||
);
|
);
|
||||||
|
@ -1263,9 +1266,9 @@ describe('DatabaseModal', () => {
|
||||||
const SSHTunnelServerPortInput = screen.getByTestId(
|
const SSHTunnelServerPortInput = screen.getByTestId(
|
||||||
'ssh-tunnel-server_port-input',
|
'ssh-tunnel-server_port-input',
|
||||||
);
|
);
|
||||||
expect(SSHTunnelServerPortInput).toHaveValue('');
|
expect(SSHTunnelServerPortInput).toHaveValue(null);
|
||||||
userEvent.type(SSHTunnelServerPortInput, '22');
|
userEvent.type(SSHTunnelServerPortInput, '22');
|
||||||
expect(SSHTunnelServerPortInput).toHaveValue('22');
|
expect(SSHTunnelServerPortInput).toHaveValue(22);
|
||||||
const SSHTunnelUsernameInput = screen.getByTestId(
|
const SSHTunnelUsernameInput = screen.getByTestId(
|
||||||
'ssh-tunnel-username-input',
|
'ssh-tunnel-username-input',
|
||||||
);
|
);
|
||||||
|
|
|
@ -20,8 +20,6 @@ import {
|
||||||
t,
|
t,
|
||||||
styled,
|
styled,
|
||||||
SupersetTheme,
|
SupersetTheme,
|
||||||
FeatureFlag,
|
|
||||||
isFeatureEnabled,
|
|
||||||
getExtensionsRegistry,
|
getExtensionsRegistry,
|
||||||
} from '@superset-ui/core';
|
} from '@superset-ui/core';
|
||||||
import React, {
|
import React, {
|
||||||
|
@ -31,6 +29,7 @@ import React, {
|
||||||
useState,
|
useState,
|
||||||
useReducer,
|
useReducer,
|
||||||
Reducer,
|
Reducer,
|
||||||
|
useCallback,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { useHistory } from 'react-router-dom';
|
import { useHistory } from 'react-router-dom';
|
||||||
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
|
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
|
||||||
|
@ -65,6 +64,7 @@ import {
|
||||||
CatalogObject,
|
CatalogObject,
|
||||||
Engines,
|
Engines,
|
||||||
ExtraJson,
|
ExtraJson,
|
||||||
|
CustomTextType,
|
||||||
} from '../types';
|
} from '../types';
|
||||||
import ExtraOptions from './ExtraOptions';
|
import ExtraOptions from './ExtraOptions';
|
||||||
import SqlAlchemyForm from './SqlAlchemyForm';
|
import SqlAlchemyForm from './SqlAlchemyForm';
|
||||||
|
@ -208,8 +208,8 @@ export type DBReducerActionType =
|
||||||
| {
|
| {
|
||||||
type:
|
type:
|
||||||
| ActionType.Reset
|
| ActionType.Reset
|
||||||
| ActionType.AddTableCatalogSheet
|
| ActionType.RemoveSSHTunnelConfig
|
||||||
| ActionType.RemoveSSHTunnelConfig;
|
| ActionType.AddTableCatalogSheet;
|
||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: ActionType.RemoveTableCatalogSheet;
|
type: ActionType.RemoveTableCatalogSheet;
|
||||||
|
@ -595,7 +595,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
const SSHTunnelSwitchComponent =
|
const SSHTunnelSwitchComponent =
|
||||||
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
|
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
|
||||||
|
|
||||||
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean>(false);
|
const [useSSHTunneling, setUseSSHTunneling] = useState<boolean | undefined>(
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
|
||||||
let dbConfigExtraExtension = extensionsRegistry.get(
|
let dbConfigExtraExtension = extensionsRegistry.get(
|
||||||
'databaseconnection.extraOption',
|
'databaseconnection.extraOption',
|
||||||
|
@ -618,14 +620,6 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
const dbImages = getDatabaseImages();
|
const dbImages = getDatabaseImages();
|
||||||
const connectionAlert = getConnectionAlert();
|
const connectionAlert = getConnectionAlert();
|
||||||
const isEditMode = !!databaseId;
|
const isEditMode = !!databaseId;
|
||||||
const disableSSHTunnelingForEngine = (
|
|
||||||
availableDbs?.databases?.find(
|
|
||||||
(DB: DatabaseObject) =>
|
|
||||||
DB.backend === db?.engine || DB.engine === db?.engine,
|
|
||||||
) as DatabaseObject
|
|
||||||
)?.engine_information?.disable_ssh_tunneling;
|
|
||||||
const isSSHTunneling =
|
|
||||||
isFeatureEnabled(FeatureFlag.SshTunneling) && !disableSSHTunnelingForEngine;
|
|
||||||
const hasAlert =
|
const hasAlert =
|
||||||
connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]);
|
connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]);
|
||||||
const useSqlAlchemyForm =
|
const useSqlAlchemyForm =
|
||||||
|
@ -659,7 +653,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
extra: db?.extra,
|
extra: db?.extra,
|
||||||
masked_encrypted_extra: db?.masked_encrypted_extra || '',
|
masked_encrypted_extra: db?.masked_encrypted_extra || '',
|
||||||
server_cert: db?.server_cert || undefined,
|
server_cert: db?.server_cert || undefined,
|
||||||
ssh_tunnel: db?.ssh_tunnel || undefined,
|
ssh_tunnel:
|
||||||
|
!isEmpty(db?.ssh_tunnel) && useSSHTunneling
|
||||||
|
? {
|
||||||
|
...db.ssh_tunnel,
|
||||||
|
server_port: Number(db.ssh_tunnel!.server_port),
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
};
|
};
|
||||||
setTestInProgress(true);
|
setTestInProgress(true);
|
||||||
testDatabaseConnection(
|
testDatabaseConnection(
|
||||||
|
@ -687,10 +687,36 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const onChange = useCallback(
|
||||||
|
(
|
||||||
|
type: DBReducerActionType['type'],
|
||||||
|
payload: CustomTextType | DBReducerPayloadType,
|
||||||
|
) => {
|
||||||
|
setDB({ type, payload } as DBReducerActionType);
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleClearValidationErrors = useCallback(() => {
|
||||||
|
setValidationErrors(null);
|
||||||
|
}, [setValidationErrors]);
|
||||||
|
|
||||||
|
const handleParametersChange = useCallback(
|
||||||
|
({ target }: { target: HTMLInputElement }) => {
|
||||||
|
onChange(ActionType.ParametersChange, {
|
||||||
|
type: target.type,
|
||||||
|
name: target.name,
|
||||||
|
checked: target.checked,
|
||||||
|
value: target.value,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[onChange],
|
||||||
|
);
|
||||||
|
|
||||||
const onClose = () => {
|
const onClose = () => {
|
||||||
setDB({ type: ActionType.Reset });
|
setDB({ type: ActionType.Reset });
|
||||||
setHasConnectedDb(false);
|
setHasConnectedDb(false);
|
||||||
setValidationErrors(null); // reset validation errors on close
|
handleClearValidationErrors(); // reset validation errors on close
|
||||||
clearError();
|
clearError();
|
||||||
setEditNewDb(false);
|
setEditNewDb(false);
|
||||||
setFileList([]);
|
setFileList([]);
|
||||||
|
@ -705,7 +731,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
setSSHTunnelPrivateKeys({});
|
setSSHTunnelPrivateKeys({});
|
||||||
setSSHTunnelPrivateKeyPasswords({});
|
setSSHTunnelPrivateKeyPasswords({});
|
||||||
setConfirmedOverwrite(false);
|
setConfirmedOverwrite(false);
|
||||||
setUseSSHTunneling(false);
|
setUseSSHTunneling(undefined);
|
||||||
onHide();
|
onHide();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -729,12 +755,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
setImportingErrorMessage(msg);
|
setImportingErrorMessage(msg);
|
||||||
});
|
});
|
||||||
|
|
||||||
const onChange = (type: any, payload: any) => {
|
|
||||||
setDB({ type, payload } as DBReducerActionType);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onSave = async () => {
|
const onSave = async () => {
|
||||||
let dbConfigExtraExtensionOnSaveError;
|
let dbConfigExtraExtensionOnSaveError;
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
|
||||||
dbConfigExtraExtension
|
dbConfigExtraExtension
|
||||||
?.onSave(extraExtensionComponentState, db)
|
?.onSave(extraExtensionComponentState, db)
|
||||||
.then(({ error }: { error: any }) => {
|
.then(({ error }: { error: any }) => {
|
||||||
|
@ -743,6 +768,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
addDangerToast(error);
|
addDangerToast(error);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if (dbConfigExtraExtensionOnSaveError) {
|
if (dbConfigExtraExtensionOnSaveError) {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
return;
|
return;
|
||||||
|
@ -762,17 +788,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// only do validation for non ssh tunnel connections
|
const errors = await getValidation(dbToUpdate, true);
|
||||||
if (!dbToUpdate?.ssh_tunnel) {
|
if (!isEmpty(validationErrors) || errors?.length) {
|
||||||
// make sure that button spinner animates
|
addDangerToast(
|
||||||
setLoading(true);
|
t('Connection failed, please check your connection settings.'),
|
||||||
const errors = await getValidation(dbToUpdate, true);
|
);
|
||||||
if ((validationErrors && !isEmpty(validationErrors)) || errors) {
|
|
||||||
setLoading(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// end spinner animation
|
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const parameters_schema = isEditMode
|
const parameters_schema = isEditMode
|
||||||
|
@ -829,7 +851,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
setLoading(true);
|
// strictly checking for false as an indication that the toggle got unchecked
|
||||||
|
if (useSSHTunneling === false) {
|
||||||
|
// remove ssh tunnel
|
||||||
|
dbToUpdate.ssh_tunnel = null;
|
||||||
|
}
|
||||||
|
|
||||||
if (db?.id) {
|
if (db?.id) {
|
||||||
const result = await updateResource(
|
const result = await updateResource(
|
||||||
db.id as number,
|
db.id as number,
|
||||||
|
@ -1282,10 +1309,10 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
}, [sshPrivateKeyPasswordNeeded]);
|
}, [sshPrivateKeyPasswordNeeded]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (db && isSSHTunneling) {
|
if (db?.parameters?.ssh !== undefined) {
|
||||||
setUseSSHTunneling(!isEmpty(db?.ssh_tunnel));
|
setUseSSHTunneling(db.parameters.ssh);
|
||||||
}
|
}
|
||||||
}, [db, isSSHTunneling]);
|
}, [db?.parameters?.ssh]);
|
||||||
|
|
||||||
const onDbImport = async (info: UploadChangeParam) => {
|
const onDbImport = async (info: UploadChangeParam) => {
|
||||||
setImportingErrorMessage('');
|
setImportingErrorMessage('');
|
||||||
|
@ -1550,17 +1577,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
const renderSSHTunnelForm = () => (
|
const renderSSHTunnelForm = () => (
|
||||||
<SSHTunnelForm
|
<SSHTunnelForm
|
||||||
db={db as DatabaseObject}
|
db={db as DatabaseObject}
|
||||||
onSSHTunnelParametersChange={({
|
onSSHTunnelParametersChange={({ target }) => {
|
||||||
target,
|
|
||||||
}: {
|
|
||||||
target: HTMLInputElement | HTMLTextAreaElement;
|
|
||||||
}) =>
|
|
||||||
onChange(ActionType.ParametersSSHTunnelChange, {
|
onChange(ActionType.ParametersSSHTunnelChange, {
|
||||||
type: target.type,
|
type: target.type,
|
||||||
name: target.name,
|
name: target.name,
|
||||||
value: target.value,
|
value: target.value,
|
||||||
})
|
});
|
||||||
}
|
handleClearValidationErrors();
|
||||||
|
}}
|
||||||
setSSHTunnelLoginMethod={(method: AuthType) =>
|
setSSHTunnelLoginMethod={(method: AuthType) =>
|
||||||
setDB({
|
setDB({
|
||||||
type: ActionType.SetSSHTunnelLoginMethod,
|
type: ActionType.SetSSHTunnelLoginMethod,
|
||||||
|
@ -1623,14 +1647,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
payload: { indexToDelete: idx },
|
payload: { indexToDelete: idx },
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
onParametersChange={({ target }: { target: HTMLInputElement }) =>
|
onParametersChange={handleParametersChange}
|
||||||
onChange(ActionType.ParametersChange, {
|
|
||||||
type: target.type,
|
|
||||||
name: target.name,
|
|
||||||
checked: target.checked,
|
|
||||||
value: target.value,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
onChange={({ target }: { target: HTMLInputElement }) =>
|
onChange={({ target }: { target: HTMLInputElement }) =>
|
||||||
onChange(ActionType.TextChange, {
|
onChange(ActionType.TextChange, {
|
||||||
name: target.name,
|
name: target.name,
|
||||||
|
@ -1640,9 +1657,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
getValidation={() => getValidation(db)}
|
getValidation={() => getValidation(db)}
|
||||||
validationErrors={validationErrors}
|
validationErrors={validationErrors}
|
||||||
getPlaceholder={getPlaceholder}
|
getPlaceholder={getPlaceholder}
|
||||||
clearValidationErrors={() => setValidationErrors(null)}
|
clearValidationErrors={handleClearValidationErrors}
|
||||||
/>
|
/>
|
||||||
{db?.parameters?.ssh && (
|
{useSSHTunneling && (
|
||||||
<SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer>
|
<SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
|
@ -1792,13 +1809,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||||
testInProgress={testInProgress}
|
testInProgress={testInProgress}
|
||||||
>
|
>
|
||||||
<SSHTunnelSwitchComponent
|
<SSHTunnelSwitchComponent
|
||||||
isEditMode={isEditMode}
|
dbModel={dbModel}
|
||||||
dbFetched={dbFetched}
|
db={db as DatabaseObject}
|
||||||
disableSSHTunnelingForEngine={disableSSHTunnelingForEngine}
|
changeMethods={{
|
||||||
useSSHTunneling={useSSHTunneling}
|
onParametersChange: handleParametersChange,
|
||||||
setUseSSHTunneling={setUseSSHTunneling}
|
}}
|
||||||
setDB={setDB}
|
clearValidationErrors={handleClearValidationErrors}
|
||||||
isSSHTunneling={isSSHTunneling}
|
|
||||||
/>
|
/>
|
||||||
{useSSHTunneling && renderSSHTunnelForm()}
|
{useSSHTunneling && renderSSHTunnelForm()}
|
||||||
</SqlAlchemyForm>
|
</SqlAlchemyForm>
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
|
import { JsonObject } from '@superset-ui/core';
|
||||||
|
import { InputProps } from 'antd/lib/input';
|
||||||
|
import { ChangeEvent, EventHandler, FormEvent } from 'react';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
* or more contributor license agreements. See the NOTICE file
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
@ -108,7 +112,7 @@ export type DatabaseObject = {
|
||||||
};
|
};
|
||||||
|
|
||||||
// SSH Tunnel information
|
// SSH Tunnel information
|
||||||
ssh_tunnel?: SSHTunnelObject;
|
ssh_tunnel?: SSHTunnelObject | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type DatabaseForm = {
|
export type DatabaseForm = {
|
||||||
|
@ -195,6 +199,10 @@ export type DatabaseForm = {
|
||||||
};
|
};
|
||||||
preferred: boolean;
|
preferred: boolean;
|
||||||
sqlalchemy_uri_placeholder: string;
|
sqlalchemy_uri_placeholder: string;
|
||||||
|
engine_information: {
|
||||||
|
supports_file_upload: boolean;
|
||||||
|
disable_ssh_tunneling: boolean;
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// the values should align with the database
|
// the values should align with the database
|
||||||
|
@ -231,3 +239,73 @@ export interface ExtraJson {
|
||||||
};
|
};
|
||||||
version?: string;
|
version?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type CustomTextType = {
|
||||||
|
value?: string | boolean | number;
|
||||||
|
type?: string | null;
|
||||||
|
name?: string;
|
||||||
|
checked?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
type CustomHTMLInputElement = Omit<Partial<CustomTextType>, 'value' | 'type'> &
|
||||||
|
CustomTextType;
|
||||||
|
|
||||||
|
type CustomHTMLTextAreaElement = Omit<
|
||||||
|
Partial<CustomTextType>,
|
||||||
|
'value' | 'type'
|
||||||
|
> &
|
||||||
|
CustomTextType;
|
||||||
|
|
||||||
|
export type CustomParametersChangeType<T = CustomTextType> =
|
||||||
|
| FormEvent<InputProps>
|
||||||
|
| { target: T };
|
||||||
|
|
||||||
|
export type CustomEventHandlerType = EventHandler<
|
||||||
|
ChangeEvent<CustomHTMLInputElement | CustomHTMLTextAreaElement>
|
||||||
|
>;
|
||||||
|
|
||||||
|
export interface FieldPropTypes {
|
||||||
|
required: boolean;
|
||||||
|
hasTooltip?: boolean;
|
||||||
|
tooltipText?: (value: any) => string;
|
||||||
|
placeholder?: string;
|
||||||
|
onParametersChange: (event: CustomParametersChangeType) => void;
|
||||||
|
onParametersUploadFileChange: (value: any) => string;
|
||||||
|
changeMethods: {
|
||||||
|
onParametersChange: (event: CustomParametersChangeType) => void;
|
||||||
|
} & {
|
||||||
|
onChange: (value: any) => string;
|
||||||
|
} & {
|
||||||
|
onQueryChange: (value: any) => string;
|
||||||
|
} & { onParametersUploadFileChange: (value: any) => string } & {
|
||||||
|
onAddTableCatalog: () => void;
|
||||||
|
onRemoveTableCatalog: (idx: number) => void;
|
||||||
|
} & {
|
||||||
|
onExtraInputChange: (value: any) => void;
|
||||||
|
onSSHTunnelParametersChange: CustomEventHandlerType;
|
||||||
|
};
|
||||||
|
validationErrors: JsonObject | null;
|
||||||
|
getValidation: () => void;
|
||||||
|
clearValidationErrors: () => void;
|
||||||
|
db?: DatabaseObject;
|
||||||
|
dbModel?: DatabaseForm;
|
||||||
|
field: string;
|
||||||
|
isEditMode?: boolean;
|
||||||
|
sslForced?: boolean;
|
||||||
|
defaultDBName?: string;
|
||||||
|
editNewDb?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChangeMethodsType = FieldPropTypes['changeMethods'];
|
||||||
|
|
||||||
|
// changeMethods compatibility with dynamic forms
|
||||||
|
type SwitchPropsChangeMethodsType = {
|
||||||
|
onParametersChange: ChangeMethodsType['onParametersChange'];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type SwitchProps = {
|
||||||
|
dbModel: DatabaseForm;
|
||||||
|
db: DatabaseObject;
|
||||||
|
changeMethods: SwitchPropsChangeMethodsType;
|
||||||
|
clearValidationErrors: () => void;
|
||||||
|
};
|
||||||
|
|
|
@ -35,7 +35,8 @@ import Chart, { Slice } from 'src/types/Chart';
|
||||||
import copyTextToClipboard from 'src/utils/copy';
|
import copyTextToClipboard from 'src/utils/copy';
|
||||||
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
|
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
|
||||||
import SupersetText from 'src/utils/textUtils';
|
import SupersetText from 'src/utils/textUtils';
|
||||||
import { FavoriteStatus, ImportResourceName, DatabaseObject } from './types';
|
import { DatabaseObject } from 'src/features/databases/types';
|
||||||
|
import { FavoriteStatus, ImportResourceName } from './types';
|
||||||
|
|
||||||
interface ListViewResourceState<D extends object = any> {
|
interface ListViewResourceState<D extends object = any> {
|
||||||
loading: boolean;
|
loading: boolean;
|
||||||
|
@ -691,7 +692,7 @@ export const getDatabaseDocumentationLinks = () =>
|
||||||
SupersetText.DB_CONNECTION_DOC_LINKS;
|
SupersetText.DB_CONNECTION_DOC_LINKS;
|
||||||
|
|
||||||
export const testDatabaseConnection = (
|
export const testDatabaseConnection = (
|
||||||
connection: DatabaseObject,
|
connection: Partial<DatabaseObject>,
|
||||||
handleErrorMsg: (errorMsg: string) => void,
|
handleErrorMsg: (errorMsg: string) => void,
|
||||||
addSuccessToast: (arg0: string) => void,
|
addSuccessToast: (arg0: string) => void,
|
||||||
) => {
|
) => {
|
||||||
|
@ -745,7 +746,7 @@ export function useDatabaseValidation() {
|
||||||
const getValidation = useCallback(
|
const getValidation = useCallback(
|
||||||
(database: Partial<DatabaseObject> | null, onCreate = false) => {
|
(database: Partial<DatabaseObject> | null, onCreate = false) => {
|
||||||
if (database?.parameters?.ssh) {
|
if (database?.parameters?.ssh) {
|
||||||
// when ssh tunnel is enabled we don't want to render any validation errors
|
// TODO: /validate_parameters/ and related utils should support ssh tunnel
|
||||||
setValidationErrors(null);
|
setValidationErrors(null);
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
from flask_babel import gettext as _
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
from superset import is_feature_enabled
|
from superset import is_feature_enabled
|
||||||
|
@ -33,6 +34,7 @@ from superset.commands.database.exceptions import (
|
||||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
SSHTunnelCreateFailedError,
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
SSHTunnelingNotEnabledError,
|
SSHTunnelingNotEnabledError,
|
||||||
SSHTunnelInvalidError,
|
SSHTunnelInvalidError,
|
||||||
)
|
)
|
||||||
|
@ -57,7 +59,11 @@ class CreateDatabaseCommand(BaseCommand):
|
||||||
try:
|
try:
|
||||||
# Test connection before starting create transaction
|
# Test connection before starting create transaction
|
||||||
TestConnectionDatabaseCommand(self._properties).run()
|
TestConnectionDatabaseCommand(self._properties).run()
|
||||||
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
|
except (
|
||||||
|
SupersetErrorsException,
|
||||||
|
SSHTunnelingNotEnabledError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
) as ex:
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||||
|
@ -103,6 +109,7 @@ class CreateDatabaseCommand(BaseCommand):
|
||||||
SSHTunnelInvalidError,
|
SSHTunnelInvalidError,
|
||||||
SSHTunnelCreateFailedError,
|
SSHTunnelCreateFailedError,
|
||||||
SSHTunnelingNotEnabledError,
|
SSHTunnelingNotEnabledError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
) as ex:
|
) as ex:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
|
@ -140,6 +147,7 @@ class CreateDatabaseCommand(BaseCommand):
|
||||||
# Check database_name uniqueness
|
# Check database_name uniqueness
|
||||||
if not DatabaseDAO.validate_uniqueness(database_name):
|
if not DatabaseDAO.validate_uniqueness(database_name):
|
||||||
exceptions.append(DatabaseExistsValidationError())
|
exceptions.append(DatabaseExistsValidationError())
|
||||||
|
|
||||||
if exceptions:
|
if exceptions:
|
||||||
exception = DatabaseInvalidError()
|
exception = DatabaseInvalidError()
|
||||||
exception.extend(exceptions)
|
exception.extend(exceptions)
|
||||||
|
|
|
@ -23,11 +23,13 @@ from marshmallow import ValidationError
|
||||||
from superset.commands.base import BaseCommand
|
from superset.commands.base import BaseCommand
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
SSHTunnelCreateFailedError,
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
SSHTunnelInvalidError,
|
SSHTunnelInvalidError,
|
||||||
SSHTunnelRequiredFieldValidationError,
|
SSHTunnelRequiredFieldValidationError,
|
||||||
)
|
)
|
||||||
from superset.daos.database import SSHTunnelDAO
|
from superset.daos.database import SSHTunnelDAO
|
||||||
from superset.daos.exceptions import DAOCreateFailedError
|
from superset.daos.exceptions import DAOCreateFailedError
|
||||||
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.extensions import event_logger
|
from superset.extensions import event_logger
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
||||||
|
@ -35,9 +37,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CreateSSHTunnelCommand(BaseCommand):
|
class CreateSSHTunnelCommand(BaseCommand):
|
||||||
|
_database: Database
|
||||||
|
|
||||||
def __init__(self, database: Database, data: dict[str, Any]):
|
def __init__(self, database: Database, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._properties["database"] = database
|
self._properties["database"] = database
|
||||||
|
self._database = database
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
try:
|
try:
|
||||||
|
@ -57,16 +62,22 @@ class CreateSSHTunnelCommand(BaseCommand):
|
||||||
server_address: Optional[str] = self._properties.get("server_address")
|
server_address: Optional[str] = self._properties.get("server_address")
|
||||||
server_port: Optional[int] = self._properties.get("server_port")
|
server_port: Optional[int] = self._properties.get("server_port")
|
||||||
username: Optional[str] = self._properties.get("username")
|
username: Optional[str] = self._properties.get("username")
|
||||||
|
password: Optional[str] = self._properties.get("password")
|
||||||
private_key: Optional[str] = self._properties.get("private_key")
|
private_key: Optional[str] = self._properties.get("private_key")
|
||||||
private_key_password: Optional[str] = self._properties.get(
|
private_key_password: Optional[str] = self._properties.get(
|
||||||
"private_key_password"
|
"private_key_password"
|
||||||
)
|
)
|
||||||
|
url = make_url_safe(self._database.sqlalchemy_uri)
|
||||||
|
if not url.port:
|
||||||
|
raise SSHTunnelDatabasePortError()
|
||||||
if not server_address:
|
if not server_address:
|
||||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
|
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
|
||||||
if not server_port:
|
if not server_port:
|
||||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
|
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
|
||||||
if not username:
|
if not username:
|
||||||
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
|
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
|
||||||
|
if not private_key and not password:
|
||||||
|
exceptions.append(SSHTunnelRequiredFieldValidationError("password"))
|
||||||
if private_key_password and private_key is None:
|
if private_key_password and private_key is None:
|
||||||
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
|
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||||
if exceptions:
|
if exceptions:
|
||||||
|
|
|
@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError):
|
||||||
message = _("SSH Tunnel parameters are invalid.")
|
message = _("SSH Tunnel parameters are invalid.")
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelDatabasePortError(CommandInvalidError):
|
||||||
|
message = _("A database port is required when connecting via SSH Tunnel.")
|
||||||
|
|
||||||
|
|
||||||
class SSHTunnelUpdateFailedError(UpdateFailedError):
|
class SSHTunnelUpdateFailedError(UpdateFailedError):
|
||||||
message = _("SSH Tunnel could not be updated.")
|
message = _("SSH Tunnel could not be updated.")
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model
|
||||||
|
|
||||||
from superset.commands.base import BaseCommand
|
from superset.commands.base import BaseCommand
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
SSHTunnelInvalidError,
|
SSHTunnelInvalidError,
|
||||||
SSHTunnelNotFoundError,
|
SSHTunnelNotFoundError,
|
||||||
SSHTunnelRequiredFieldValidationError,
|
SSHTunnelRequiredFieldValidationError,
|
||||||
|
@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
from superset.daos.database import SSHTunnelDAO
|
from superset.daos.database import SSHTunnelDAO
|
||||||
from superset.daos.exceptions import DAOUpdateFailedError
|
from superset.daos.exceptions import DAOUpdateFailedError
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
from superset.databases.utils import make_url_safe
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -39,20 +41,33 @@ class UpdateSSHTunnelCommand(BaseCommand):
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._model: Optional[SSHTunnel] = None
|
self._model: Optional[SSHTunnel] = None
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Optional[Model]:
|
||||||
self.validate()
|
self.validate()
|
||||||
try:
|
try:
|
||||||
if self._model is not None: # So we dont get incompatible types error
|
if self._model is None:
|
||||||
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
return None
|
||||||
|
|
||||||
|
# unset password if private key is provided
|
||||||
|
if self._properties.get("private_key"):
|
||||||
|
self._properties["password"] = None
|
||||||
|
|
||||||
|
# unset private key and password if password is provided
|
||||||
|
if self._properties.get("password"):
|
||||||
|
self._properties["private_key"] = None
|
||||||
|
self._properties["private_key_password"] = None
|
||||||
|
|
||||||
|
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||||
|
return tunnel
|
||||||
except DAOUpdateFailedError as ex:
|
except DAOUpdateFailedError as ex:
|
||||||
raise SSHTunnelUpdateFailedError() from ex
|
raise SSHTunnelUpdateFailedError() from ex
|
||||||
return tunnel
|
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# Validate/populate model exists
|
# Validate/populate model exists
|
||||||
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||||
if not self._model:
|
if not self._model:
|
||||||
raise SSHTunnelNotFoundError()
|
raise SSHTunnelNotFoundError()
|
||||||
|
|
||||||
|
url = make_url_safe(self._model.database.sqlalchemy_uri)
|
||||||
private_key: Optional[str] = self._properties.get("private_key")
|
private_key: Optional[str] = self._properties.get("private_key")
|
||||||
private_key_password: Optional[str] = self._properties.get(
|
private_key_password: Optional[str] = self._properties.get(
|
||||||
"private_key_password"
|
"private_key_password"
|
||||||
|
@ -61,3 +76,5 @@ class UpdateSSHTunnelCommand(BaseCommand):
|
||||||
raise SSHTunnelInvalidError(
|
raise SSHTunnelInvalidError(
|
||||||
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
|
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
|
||||||
)
|
)
|
||||||
|
if not url.port:
|
||||||
|
raise SSHTunnelDatabasePortError()
|
||||||
|
|
|
@ -32,7 +32,10 @@ from superset.commands.database.exceptions import (
|
||||||
DatabaseTestConnectionDriverError,
|
DatabaseTestConnectionDriverError,
|
||||||
DatabaseTestConnectionUnexpectedError,
|
DatabaseTestConnectionUnexpectedError,
|
||||||
)
|
)
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
SSHTunnelingNotEnabledError,
|
||||||
|
)
|
||||||
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
|
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
|
@ -61,20 +64,22 @@ def get_log_connection_action(
|
||||||
|
|
||||||
|
|
||||||
class TestConnectionDatabaseCommand(BaseCommand):
|
class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
|
_model: Optional[Database] = None
|
||||||
|
_context: dict[str, Any]
|
||||||
|
_uri: str
|
||||||
|
|
||||||
def __init__(self, data: dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[Database] = None
|
|
||||||
|
|
||||||
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
|
if (database_name := self._properties.get("database_name")) is not None:
|
||||||
self.validate()
|
self._model = DatabaseDAO.get_database_by_name(database_name)
|
||||||
ex_str = ""
|
|
||||||
uri = self._properties.get("sqlalchemy_uri", "")
|
uri = self._properties.get("sqlalchemy_uri", "")
|
||||||
if self._model and uri == self._model.safe_sqlalchemy_uri():
|
if self._model and uri == self._model.safe_sqlalchemy_uri():
|
||||||
uri = self._model.sqlalchemy_uri_decrypted
|
uri = self._model.sqlalchemy_uri_decrypted
|
||||||
ssh_tunnel = self._properties.get("ssh_tunnel")
|
|
||||||
|
|
||||||
# context for error messages
|
|
||||||
url = make_url_safe(uri)
|
url = make_url_safe(uri)
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
"hostname": url.host,
|
"hostname": url.host,
|
||||||
"password": url.password,
|
"password": url.password,
|
||||||
|
@ -83,6 +88,14 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
"database": url.database,
|
"database": url.database,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self._context = context
|
||||||
|
self._uri = uri
|
||||||
|
|
||||||
|
def run(self) -> None: # pylint: disable=too-many-statements
|
||||||
|
self.validate()
|
||||||
|
ex_str = ""
|
||||||
|
ssh_tunnel = self._properties.get("ssh_tunnel")
|
||||||
|
|
||||||
serialized_encrypted_extra = self._properties.get(
|
serialized_encrypted_extra = self._properties.get(
|
||||||
"masked_encrypted_extra",
|
"masked_encrypted_extra",
|
||||||
"{}",
|
"{}",
|
||||||
|
@ -103,15 +116,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
encrypted_extra=serialized_encrypted_extra,
|
encrypted_extra=serialized_encrypted_extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
database.set_sqlalchemy_uri(uri)
|
database.set_sqlalchemy_uri(self._uri)
|
||||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||||
|
|
||||||
# Generate tunnel if present in the properties
|
# Generate tunnel if present in the properties
|
||||||
if ssh_tunnel:
|
if ssh_tunnel:
|
||||||
if not is_feature_enabled("SSH_TUNNELING"):
|
# unmask password while allowing for updated values
|
||||||
raise SSHTunnelingNotEnabledError()
|
|
||||||
# If there's an existing tunnel for that DB we need to use the stored
|
|
||||||
# password, private_key and private_key_password instead
|
|
||||||
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
|
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
|
||||||
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
|
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
|
||||||
ssh_tunnel = unmask_password_info(
|
ssh_tunnel = unmask_password_info(
|
||||||
|
@ -186,7 +196,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
engine=database.db_engine_spec.__name__,
|
engine=database.db_engine_spec.__name__,
|
||||||
)
|
)
|
||||||
# check for custom errors (wrong username, wrong password, etc)
|
# check for custom errors (wrong username, wrong password, etc)
|
||||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
||||||
raise SupersetErrorsException(errors) from ex
|
raise SupersetErrorsException(errors) from ex
|
||||||
except SupersetSecurityException as ex:
|
except SupersetSecurityException as ex:
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
|
@ -221,9 +231,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
),
|
),
|
||||||
engine=database.db_engine_spec.__name__,
|
engine=database.db_engine_spec.__name__,
|
||||||
)
|
)
|
||||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
||||||
raise DatabaseTestConnectionUnexpectedError(errors) from ex
|
raise DatabaseTestConnectionUnexpectedError(errors) from ex
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
if (database_name := self._properties.get("database_name")) is not None:
|
if self._properties.get("ssh_tunnel"):
|
||||||
self._model = DatabaseDAO.get_database_by_name(database_name)
|
if not is_feature_enabled("SSH_TUNNELING"):
|
||||||
|
raise SSHTunnelingNotEnabledError()
|
||||||
|
if not self._context.get("port"):
|
||||||
|
raise SSHTunnelDatabasePortError()
|
||||||
|
|
|
@ -18,6 +18,7 @@ import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_appbuilder.models.sqla import Model
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
from flask_babel import gettext as _
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
from superset import is_feature_enabled
|
from superset import is_feature_enabled
|
||||||
|
@ -30,8 +31,11 @@ from superset.commands.database.exceptions import (
|
||||||
DatabaseUpdateFailedError,
|
DatabaseUpdateFailedError,
|
||||||
)
|
)
|
||||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||||
|
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
SSHTunnelCreateFailedError,
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
SSHTunnelDeleteFailedError,
|
||||||
SSHTunnelingNotEnabledError,
|
SSHTunnelingNotEnabledError,
|
||||||
SSHTunnelInvalidError,
|
SSHTunnelInvalidError,
|
||||||
SSHTunnelUpdateFailedError,
|
SSHTunnelUpdateFailedError,
|
||||||
|
@ -47,15 +51,21 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpdateDatabaseCommand(BaseCommand):
|
class UpdateDatabaseCommand(BaseCommand):
|
||||||
|
_model: Optional[Database]
|
||||||
|
|
||||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._model: Optional[Database] = None
|
self._model: Optional[Database] = None
|
||||||
|
|
||||||
def run(self) -> Model:
|
def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches
|
||||||
self.validate()
|
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||||
|
|
||||||
if not self._model:
|
if not self._model:
|
||||||
raise DatabaseNotFoundError()
|
raise DatabaseNotFoundError()
|
||||||
|
|
||||||
|
self.validate()
|
||||||
|
|
||||||
old_database_name = self._model.database_name
|
old_database_name = self._model.database_name
|
||||||
|
|
||||||
# unmask ``encrypted_extra``
|
# unmask ``encrypted_extra``
|
||||||
|
@ -70,36 +80,59 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
database = DatabaseDAO.update(self._model, self._properties, commit=False)
|
database = DatabaseDAO.update(self._model, self._properties, commit=False)
|
||||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||||
|
|
||||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||||
|
|
||||||
|
if "ssh_tunnel" in self._properties:
|
||||||
if not is_feature_enabled("SSH_TUNNELING"):
|
if not is_feature_enabled("SSH_TUNNELING"):
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
raise SSHTunnelingNotEnabledError()
|
raise SSHTunnelingNotEnabledError()
|
||||||
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
|
|
||||||
if existing_ssh_tunnel_model is None:
|
if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
|
||||||
# We couldn't found an existing tunnel so we need to create one
|
# We need to remove the existing tunnel
|
||||||
try:
|
try:
|
||||||
CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
|
DeleteSSHTunnelCommand(ssh_tunnel.id).run()
|
||||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
ssh_tunnel = None
|
||||||
# So we can show the original message
|
except SSHTunnelDeleteFailedError as ex:
|
||||||
raise ex
|
|
||||||
except Exception as ex:
|
|
||||||
raise DatabaseUpdateFailedError() from ex
|
|
||||||
else:
|
|
||||||
# We found an existing tunnel so we need to update it
|
|
||||||
try:
|
|
||||||
UpdateSSHTunnelCommand(
|
|
||||||
existing_ssh_tunnel_model.id, ssh_tunnel_properties
|
|
||||||
).run()
|
|
||||||
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
|
|
||||||
# So we can show the original message
|
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise DatabaseUpdateFailedError() from ex
|
raise DatabaseUpdateFailedError() from ex
|
||||||
|
|
||||||
|
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||||
|
if ssh_tunnel is None:
|
||||||
|
# We couldn't found an existing tunnel so we need to create one
|
||||||
|
try:
|
||||||
|
ssh_tunnel = CreateSSHTunnelCommand(
|
||||||
|
database, ssh_tunnel_properties
|
||||||
|
).run()
|
||||||
|
except (
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
) as ex:
|
||||||
|
# So we can show the original message
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
raise DatabaseUpdateFailedError() from ex
|
||||||
|
else:
|
||||||
|
# We found an existing tunnel so we need to update it
|
||||||
|
try:
|
||||||
|
ssh_tunnel_id = ssh_tunnel.id
|
||||||
|
ssh_tunnel = UpdateSSHTunnelCommand(
|
||||||
|
ssh_tunnel_id, ssh_tunnel_properties
|
||||||
|
).run()
|
||||||
|
except (
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
SSHTunnelUpdateFailedError,
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
) as ex:
|
||||||
|
# So we can show the original message
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
raise DatabaseUpdateFailedError() from ex
|
||||||
|
|
||||||
# adding a new database we always want to force refresh schema list
|
# adding a new database we always want to force refresh schema list
|
||||||
# TODO Improve this simplistic implementation for catching DB conn fails
|
# TODO Improve this simplistic implementation for catching DB conn fails
|
||||||
try:
|
try:
|
||||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
|
||||||
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
|
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
|
@ -167,10 +200,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
exceptions: list[ValidationError] = []
|
exceptions: list[ValidationError] = []
|
||||||
# Validate/populate model exists
|
|
||||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
|
||||||
if not self._model:
|
|
||||||
raise DatabaseNotFoundError()
|
|
||||||
database_name: Optional[str] = self._properties.get("database_name")
|
database_name: Optional[str] = self._properties.get("database_name")
|
||||||
if database_name:
|
if database_name:
|
||||||
# Check database_name uniqueness
|
# Check database_name uniqueness
|
||||||
|
|
|
@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand
|
||||||
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
|
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
|
||||||
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
SSHTunnelDeleteFailedError,
|
SSHTunnelDeleteFailedError,
|
||||||
SSHTunnelingNotEnabledError,
|
SSHTunnelingNotEnabledError,
|
||||||
)
|
)
|
||||||
|
@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
return self.response_422(message=str(ex))
|
return self.response_422(message=str(ex))
|
||||||
except SSHTunnelingNotEnabledError as ex:
|
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||||
return self.response_400(message=str(ex))
|
return self.response_400(message=str(ex))
|
||||||
except SupersetException as ex:
|
except SupersetException as ex:
|
||||||
return self.response(ex.status, message=ex.message)
|
return self.response(ex.status, message=ex.message)
|
||||||
|
@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
return self.response_422(message=str(ex))
|
return self.response_422(message=str(ex))
|
||||||
except SSHTunnelingNotEnabledError as ex:
|
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||||
return self.response_400(message=str(ex))
|
return self.response_400(message=str(ex))
|
||||||
|
|
||||||
@expose("/<int:pk>", methods=("DELETE",))
|
@expose("/<int:pk>", methods=("DELETE",))
|
||||||
|
@ -918,7 +919,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
try:
|
try:
|
||||||
TestConnectionDatabaseCommand(item).run()
|
TestConnectionDatabaseCommand(item).run()
|
||||||
return self.response(200, message="OK")
|
return self.response(200, message="OK")
|
||||||
except SSHTunnelingNotEnabledError as ex:
|
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||||
return self.response_400(message=str(ex))
|
return self.response_400(message=str(ex))
|
||||||
|
|
||||||
@expose("/<int:pk>/related_objects/", methods=("GET",))
|
@expose("/<int:pk>/related_objects/", methods=("GET",))
|
||||||
|
|
|
@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
|
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
|
@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase):
|
||||||
db.session.delete(model)
|
db.session.delete(model)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_create_database_with_missing_port_raises_error(
|
||||||
|
self,
|
||||||
|
mock_test_connection_database_command_run,
|
||||||
|
mock_create_is_feature_enabled,
|
||||||
|
mock_get_all_schema_names,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test that missing port raises SSHTunnelDatabaseError
|
||||||
|
"""
|
||||||
|
mock_create_is_feature_enabled.return_value = True
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
|
||||||
|
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
|
||||||
|
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
|
||||||
|
database_data_with_ssh_tunnel = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": modified_sqlalchemy_uri,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
database_data_with_ssh_tunnel = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": modified_sqlalchemy_uri,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 400)
|
||||||
|
self.assertEqual(
|
||||||
|
response.get("message"),
|
||||||
|
"A database port is required when connecting via SSH Tunnel.",
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch(
|
@mock.patch(
|
||||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
)
|
)
|
||||||
|
@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase):
|
||||||
db.session.delete(model)
|
db.session.delete(model)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||||
|
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_update_database_with_missing_port_raises_error(
|
||||||
|
self,
|
||||||
|
mock_test_connection_database_command_run,
|
||||||
|
mock_create_is_feature_enabled,
|
||||||
|
mock_update_is_feature_enabled,
|
||||||
|
mock_get_all_schema_names,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test that missing port raises SSHTunnelDatabaseError
|
||||||
|
"""
|
||||||
|
mock_create_is_feature_enabled.return_value = True
|
||||||
|
mock_update_is_feature_enabled.return_value = True
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
|
||||||
|
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
|
||||||
|
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
|
||||||
|
database_data_with_ssh_tunnel = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": modified_sqlalchemy_uri,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response_create = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
|
||||||
|
uri = "api/v1/database/{}".format(response_create.get("id"))
|
||||||
|
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 400)
|
||||||
|
self.assertEqual(
|
||||||
|
response.get("message"),
|
||||||
|
"A database port is required when connecting via SSH Tunnel.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response_create.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||||
|
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||||
|
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_delete_ssh_tunnel(
|
||||||
|
self,
|
||||||
|
mock_test_connection_database_command_run,
|
||||||
|
mock_create_is_feature_enabled,
|
||||||
|
mock_update_is_feature_enabled,
|
||||||
|
mock_delete_is_feature_enabled,
|
||||||
|
mock_get_all_schema_names,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test deleting a SSH tunnel via Database update
|
||||||
|
"""
|
||||||
|
mock_create_is_feature_enabled.return_value = True
|
||||||
|
mock_update_is_feature_enabled.return_value = True
|
||||||
|
mock_delete_is_feature_enabled.return_value = True
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
}
|
||||||
|
database_data_with_ssh_tunnel = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
|
||||||
|
uri = "api/v1/database/{}".format(response.get("id"))
|
||||||
|
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||||
|
response_update = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||||
|
|
||||||
|
database_data_with_ssh_tunnel_null = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
|
||||||
|
response_update = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_ssh_tunnel is None
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
@mock.patch(
|
@mock.patch(
|
||||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,7 +19,10 @@
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_ssh_tunnel_command() -> None:
|
def test_create_ssh_tunnel_command() -> None:
|
||||||
|
@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None:
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
||||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
database = Database(
|
||||||
|
id=1,
|
||||||
|
database_name="my_database",
|
||||||
|
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
|
||||||
|
)
|
||||||
|
|
||||||
properties = {
|
properties = {
|
||||||
"database_id": database.id,
|
"database_id": database.id,
|
||||||
|
@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
||||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
database = Database(
|
||||||
|
id=1,
|
||||||
|
database_name="my_database",
|
||||||
|
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
|
||||||
|
)
|
||||||
|
|
||||||
# If we are trying to create a tunnel with a private_key_password
|
# If we are trying to create a tunnel with a private_key_password
|
||||||
# then a private_key is mandatory
|
# then a private_key is mandatory
|
||||||
|
@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
|
||||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||||
command.run()
|
command.run()
|
||||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_ssh_tunnel_command_no_port() -> None:
|
||||||
|
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
from superset.models.core import Database
|
||||||
|
|
||||||
|
database = Database(
|
||||||
|
id=1,
|
||||||
|
database_name="my_database",
|
||||||
|
sqlalchemy_uri="postgresql://u:p@localhost/db",
|
||||||
|
)
|
||||||
|
|
||||||
|
properties = {
|
||||||
|
"database": database,
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": "3005",
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
|
||||||
|
command = CreateSSHTunnelCommand(database, properties)
|
||||||
|
|
||||||
|
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
|
||||||
|
command.run()
|
||||||
|
assert str(excinfo.value) == (
|
||||||
|
"A database port is required when connecting via SSH Tunnel."
|
||||||
|
)
|
||||||
|
|
|
@ -20,11 +20,14 @@ from collections.abc import Iterator
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm.session import Session
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
|
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||||
|
SSHTunnelDatabasePortError,
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session_with_data(session: Session) -> Iterator[Session]:
|
def session_with_data(request, session: Session) -> Iterator[Session]:
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||||
engine = session.get_bind()
|
engine = session.get_bind()
|
||||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||||
|
|
||||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
|
||||||
|
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||||
sqla_table = SqlaTable(
|
sqla_table = SqlaTable(
|
||||||
table_name="my_sqla_table",
|
table_name="my_sqla_table",
|
||||||
columns=[],
|
columns=[],
|
||||||
|
@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
|
||||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||||
command.run()
|
command.run()
|
||||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
|
||||||
|
)
|
||||||
|
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
|
||||||
|
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||||
|
from superset.daos.database import DatabaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
|
||||||
|
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||||
|
|
||||||
|
assert result
|
||||||
|
assert isinstance(result, SSHTunnel)
|
||||||
|
assert 1 == result.database_id
|
||||||
|
assert "Test" == result.server_address
|
||||||
|
|
||||||
|
update_payload = {"server_address": "Test update"}
|
||||||
|
command = UpdateSSHTunnelCommand(1, update_payload)
|
||||||
|
|
||||||
|
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
|
||||||
|
command.run()
|
||||||
|
assert str(excinfo.value) == (
|
||||||
|
"A database port is required when connecting via SSH Tunnel."
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue