From 89e89de341c555a1fdbe9d3f5bccada58eb08059 Mon Sep 17 00:00:00 2001 From: Geido <60598000+geido@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:56:54 +0100 Subject: [PATCH] fix: SSH Tunnel configuration settings (#27186) --- .../src/ui-overrides/types.ts | 18 +- .../features/alerts/AlertReportModal.test.tsx | 4 +- .../CommonParameters.tsx | 35 +-- .../DatabaseConnectionForm/EncryptedField.tsx | 2 +- .../DatabaseConnectionForm/TableCatalog.tsx | 3 +- .../ValidatedInputField.tsx | 2 +- .../DatabaseConnectionForm/index.tsx | 130 +++++------ .../databases/DatabaseModal/SSHTunnelForm.tsx | 12 +- .../DatabaseModal/SSHTunnelSwitch.test.tsx | 162 ++++++++++++++ .../DatabaseModal/SSHTunnelSwitch.tsx | 82 +++++-- .../databases/DatabaseModal/index.test.tsx | 11 +- .../databases/DatabaseModal/index.tsx | 132 +++++++----- .../src/features/databases/types.ts | 80 ++++++- superset-frontend/src/views/CRUD/hooks.ts | 7 +- superset/commands/database/create.py | 10 +- .../commands/database/ssh_tunnel/create.py | 11 + .../database/ssh_tunnel/exceptions.py | 4 + .../commands/database/ssh_tunnel/update.py | 25 ++- superset/commands/database/test_connection.py | 45 ++-- superset/commands/database/update.py | 79 ++++--- superset/databases/api.py | 7 +- .../integration_tests/databases/api_tests.py | 201 ++++++++++++++++++ .../ssh_tunnel/commands/create_test.py | 45 +++- .../ssh_tunnel/commands/update_test.py | 35 ++- 24 files changed, 871 insertions(+), 271 deletions(-) create mode 100644 superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx diff --git a/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts b/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts index 45ec06e90e..60598bd4e1 100644 --- a/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts +++ b/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts @@ -44,15 +44,15 @@ interface MenuObjectChildProps { disable?: boolean; } -export interface SwitchProps { - isEditMode: boolean; - dbFetched: any; - disableSSHTunnelingForEngine?: boolean; - useSSHTunneling: boolean; - setUseSSHTunneling: React.Dispatch>; - setDB: React.Dispatch; - isSSHTunneling: boolean; -} +// loose typing to avoid any circular dependencies +// refer to SSHTunnelSwitch component for strict typing +type SwitchProps = { + db: object; + changeMethods: { + onParametersChange: (event: any) => void; + }; + clearValidationErrors: () => void; +}; type ConfigDetailsProps = { embeddedId: string; diff --git a/superset-frontend/src/features/alerts/AlertReportModal.test.tsx b/superset-frontend/src/features/alerts/AlertReportModal.test.tsx index ee9504286d..358aa27df3 100644 --- a/superset-frontend/src/features/alerts/AlertReportModal.test.tsx +++ b/superset-frontend/src/features/alerts/AlertReportModal.test.tsx @@ -541,8 +541,8 @@ test('defaults to day when CRON is not selected', async () => { useRedux: true, }); userEvent.click(screen.getByTestId('schedule-panel')); - const days = screen.getAllByTitle(/day/i, { exact: true }); - expect(days.length).toBe(2); + const day = screen.getByText('day'); + expect(day).toBeInTheDocument(); }); // Notification Method Section diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx index 7b52eab26c..3f1f5f9625 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx @@ -17,12 +17,11 @@ * under the License. */ import React from 'react'; -import { isEmpty } from 'lodash'; import { SupersetTheme, t } from '@superset-ui/core'; import { AntdSwitch } from 'src/components'; import InfoTooltip from 'src/components/InfoTooltip'; import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput'; -import { FieldPropTypes } from '.'; +import { FieldPropTypes } from '../../types'; import { toggleStyle, infoTooltip } from '../styles'; export const hostField = ({ @@ -252,35 +251,3 @@ export const forceSSLField = ({ /> ); - -export const SSHTunnelSwitch = ({ - isEditMode, - changeMethods, - clearValidationErrors, - db, -}: FieldPropTypes) => ( -
infoTooltip(theme)}> - { - changeMethods.onParametersChange({ - target: { - type: 'toggle', - name: 'ssh', - checked: true, - value: changed, - }, - }); - clearValidationErrors(); - }} - data-test="ssh-tunnel-switch" - /> - {t('SSH Tunnel')} - -
-); diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx index c5e268e569..009afc84ef 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx @@ -22,7 +22,7 @@ import { AntdButton, AntdSelect } from 'src/components'; import InfoTooltip from 'src/components/InfoTooltip'; import FormLabel from 'src/components/Form/FormLabel'; import Icons from 'src/components/Icons'; -import { FieldPropTypes } from '.'; +import { FieldPropTypes } from '../../types'; import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles'; enum CredentialInfoOptions { diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx index ed5cc94903..47a0ec1579 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx @@ -21,9 +21,8 @@ import { css, SupersetTheme, t } from '@superset-ui/core'; import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput'; import FormLabel from 'src/components/Form/FormLabel'; import Icons from 'src/components/Icons'; -import { FieldPropTypes } from '.'; import { StyledFooterButton, StyledCatalogTable } from '../styles'; -import { CatalogObject } from '../../types'; +import { CatalogObject, FieldPropTypes } from '../../types'; export const TableCatalog = ({ required, diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx index ec2e239ac4..d6794f9a21 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx @@ -19,7 +19,7 @@ import React from 'react'; import { t } from '@superset-ui/core'; import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput'; -import { FieldPropTypes } from '.'; +import { FieldPropTypes } from '../../types'; const FIELD_TEXT_MAP = { account: { diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx index e747b3c895..fc076b624f 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx @@ -17,7 +17,11 @@ * under the License. */ import React, { FormEvent } from 'react'; -import { SupersetTheme, JsonObject } from '@superset-ui/core'; +import { + SupersetTheme, + JsonObject, + getExtensionsRegistry, +} from '@superset-ui/core'; import { InputProps } from 'antd/lib/input'; import { Form } from 'src/components/Form'; import { @@ -31,13 +35,13 @@ import { portField, queryField, usernameField, - SSHTunnelSwitch, } from './CommonParameters'; import { validatedInputField } from './ValidatedInputField'; import { EncryptedField } from './EncryptedField'; import { TableCatalog } from './TableCatalog'; import { formScrollableStyles, validatedFormStyles } from '../styles'; import { DatabaseForm, DatabaseObject } from '../../types'; +import SSHTunnelSwitch from '../SSHTunnelSwitch'; export const FormFieldOrder = [ 'host', @@ -59,34 +63,10 @@ export const FormFieldOrder = [ 'ssh', ]; -export interface FieldPropTypes { - required: boolean; - hasTooltip?: boolean; - tooltipText?: (value: any) => string; - placeholder?: string; - onParametersChange: (value: any) => string; - onParametersUploadFileChange: (value: any) => string; - changeMethods: { onParametersChange: (value: any) => string } & { - onChange: (value: any) => string; - } & { - onQueryChange: (value: any) => string; - } & { onParametersUploadFileChange: (value: any) => string } & { - onAddTableCatalog: () => void; - onRemoveTableCatalog: (idx: number) => void; - } & { - onExtraInputChange: (value: any) => void; - onSSHTunnelParametersChange: (value: any) => string; - }; - validationErrors: JsonObject | null; - getValidation: () => void; - clearValidationErrors: () => void; - db?: DatabaseObject; - field: string; - isEditMode?: boolean; - sslForced?: boolean; - defaultDBName?: string; - editNewDb?: boolean; -} +const extensionsRegistry = getExtensionsRegistry(); + +const SSHTunnelSwitchComponent = + extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch; const FORM_FIELD_MAP = { host: hostField, @@ -105,7 +85,7 @@ const FORM_FIELD_MAP = { warehouse: validatedInputField, role: validatedInputField, account: validatedInputField, - ssh: SSHTunnelSwitch, + ssh: SSHTunnelSwitchComponent, }; interface DatabaseConnectionFormProps { @@ -138,7 +118,7 @@ interface DatabaseConnectionFormProps { } const DatabaseConnectionForm = ({ - dbModel: { parameters }, + dbModel, db, editNewDb, getPlaceholder, @@ -154,47 +134,51 @@ const DatabaseConnectionForm = ({ sslForced, validationErrors, clearValidationErrors, -}: DatabaseConnectionFormProps) => ( -
-
[ - formScrollableStyles, - validatedFormStyles(theme), - ]} - > - {parameters && - FormFieldOrder.filter( - (key: string) => - Object.keys(parameters.properties).includes(key) || - key === 'database_name', - ).map(field => - FORM_FIELD_MAP[field]({ - required: parameters.required?.includes(field), - changeMethods: { - onParametersChange, - onChange, - onQueryChange, - onParametersUploadFileChange, - onAddTableCatalog, - onRemoveTableCatalog, - onExtraInputChange, - }, - validationErrors, - getValidation, - clearValidationErrors, - db, - key: field, - field, - isEditMode, - sslForced, - editNewDb, - placeholder: getPlaceholder ? getPlaceholder(field) : undefined, - }), - )} -
-
-); +}: DatabaseConnectionFormProps) => { + const parameters = dbModel?.parameters; + + return ( +
+
[ + formScrollableStyles, + validatedFormStyles(theme), + ]} + > + {parameters && + FormFieldOrder.filter( + (key: string) => + Object.keys(parameters.properties).includes(key) || + key === 'database_name', + ).map(field => + FORM_FIELD_MAP[field]({ + required: parameters.required?.includes(field), + changeMethods: { + onParametersChange, + onChange, + onQueryChange, + onParametersUploadFileChange, + onAddTableCatalog, + onRemoveTableCatalog, + onExtraInputChange, + }, + validationErrors, + getValidation, + clearValidationErrors, + db, + key: field, + field, + isEditMode, + sslForced, + editNewDb, + placeholder: getPlaceholder ? getPlaceholder(field) : undefined, + }), + )} +
+
+ ); +}; export const FormFieldMap = FORM_FIELD_MAP; export default DatabaseConnectionForm; diff --git a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx index 7823d82faf..e0d1b16ff2 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -import React, { EventHandler, ChangeEvent, useState } from 'react'; +import React, { useState } from 'react'; import { t, styled } from '@superset-ui/core'; import { AntdForm, Col, Row } from 'src/components'; import { Form, FormLabel } from 'src/components/Form'; @@ -24,7 +24,7 @@ import { Radio } from 'src/components/Radio'; import { Input, TextArea } from 'src/components/Input'; import { Input as AntdInput, Tooltip } from 'antd'; import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons'; -import { DatabaseObject } from '../types'; +import { DatabaseObject, FieldPropTypes } from '../types'; import { AuthType } from '.'; const StyledDiv = styled.div` @@ -54,9 +54,7 @@ const SSHTunnelForm = ({ setSSHTunnelLoginMethod, }: { db: DatabaseObject | null; - onSSHTunnelParametersChange: EventHandler< - ChangeEvent - >; + onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange']; setSSHTunnelLoginMethod: (method: AuthType) => void; }) => { const [usePassword, setUsePassword] = useState(AuthType.Password); @@ -86,9 +84,9 @@ const SSHTunnelForm = ({ diff --git a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx new file mode 100644 index 0000000000..fef205acf2 --- /dev/null +++ b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import React from 'react'; +import { render, screen } from 'spec/helpers/testing-library'; +import userEvent from '@testing-library/user-event'; +import SSHTunnelSwitch from './SSHTunnelSwitch'; +import { DatabaseForm, DatabaseObject } from '../types'; + +jest.mock('@superset-ui/core', () => ({ + ...jest.requireActual('@superset-ui/core'), + isFeatureEnabled: jest.fn().mockReturnValue(true), +})); + +jest.mock('src/components', () => ({ + AntdSwitch: ({ + checked, + onChange, + }: { + checked: boolean; + onChange: (checked: boolean) => void; + }) => ( + + ), +})); + +const mockChangeMethods = { + onParametersChange: jest.fn(), +}; + +const mockDbModel = { + engine: 'mysql', + engine_information: { + disable_ssh_tunneling: false, + }, +} as DatabaseForm; + +const defaultDb = { + parameters: { ssh: false }, + ssh_tunnel: {}, + engine: 'mysql', +} as DatabaseObject; + +afterEach(() => { + jest.clearAllMocks(); +}); + +test('Renders SSH Tunnel switch enabled by default and toggles its state', () => { + render( + , + ); + const switchButton = screen.getByRole('switch'); + expect(switchButton).toHaveTextContent('OFF'); + userEvent.click(switchButton); + expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({ + target: { type: 'toggle', name: 'ssh', checked: true, value: true }, + }); + expect(switchButton).toHaveTextContent('ON'); +}); + +test('Does not render if SSH Tunnel is disabled', () => { + render( + , + ); + expect(screen.queryByRole('switch')).not.toBeInTheDocument(); +}); + +test('Checks the switch based on db.parameters.ssh', () => { + const dbWithSSHTunnelEnabled = { + ...defaultDb, + parameters: { ssh: true }, + } as DatabaseObject; + render( + , + ); + expect(screen.getByRole('switch')).toHaveTextContent('ON'); +}); + +test('Calls onParametersChange with true if SSH Tunnel info exists', () => { + const dbWithSSHTunnelInfo = { + ...defaultDb, + parameters: { ssh: undefined }, + ssh_tunnel: { host: 'example.com' }, + } as DatabaseObject; + render( + , + ); + expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({ + target: { type: 'toggle', name: 'ssh', checked: true, value: true }, + }); +}); + +test('Displays tooltip text on hover over the InfoTooltip', async () => { + const tooltipText = 'SSH Tunnel configuration parameters'; + render( + , + ); + + const infoTooltipTrigger = screen.getByRole('img', { + name: 'info-solid_small', + }); + expect(infoTooltipTrigger).toBeInTheDocument(); + + userEvent.hover(infoTooltipTrigger); + + const tooltip = await screen.findByText(tooltipText); + + expect(tooltip).toBeInTheDocument(); +}); diff --git a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx index 388e3c83b1..cf96864a3d 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx @@ -16,35 +16,73 @@ * specific language governing permissions and limitations * under the License. */ -import React from 'react'; -import { t, SupersetTheme, SwitchProps } from '@superset-ui/core'; +import React, { useEffect, useState } from 'react'; +import { + t, + SupersetTheme, + isFeatureEnabled, + FeatureFlag, +} from '@superset-ui/core'; import { AntdSwitch } from 'src/components'; import InfoTooltip from 'src/components/InfoTooltip'; import { isEmpty } from 'lodash'; -import { ActionType } from '.'; import { infoTooltip, toggleStyle } from './styles'; +import { SwitchProps } from '../types'; const SSHTunnelSwitch = ({ - isEditMode, - dbFetched, - useSSHTunneling, - setUseSSHTunneling, - setDB, - isSSHTunneling, -}: SwitchProps) => - isSSHTunneling ? ( + clearValidationErrors, + changeMethods, + db, + dbModel, +}: SwitchProps) => { + const [isChecked, setChecked] = useState(false); + const sshTunnelEnabled = isFeatureEnabled(FeatureFlag.SshTunneling); + const disableSSHTunnelingForEngine = + dbModel?.engine_information?.disable_ssh_tunneling || false; + const isSSHTunnelEnabled = sshTunnelEnabled && !disableSSHTunnelingForEngine; + + const handleOnChange = (changed: boolean) => { + setChecked(changed); + changeMethods.onParametersChange({ + target: { + type: 'toggle', + name: 'ssh', + checked: true, + value: changed, + }, + }); + clearValidationErrors(); + }; + + useEffect(() => { + if (isSSHTunnelEnabled && db?.parameters?.ssh !== undefined) { + setChecked(db.parameters.ssh); + } + }, [db?.parameters?.ssh, isSSHTunnelEnabled]); + + useEffect(() => { + if ( + isSSHTunnelEnabled && + db?.parameters?.ssh === undefined && + !isEmpty(db?.ssh_tunnel) + ) { + // reflecting the state of the ssh tunnel on first load + changeMethods.onParametersChange({ + target: { + type: 'toggle', + name: 'ssh', + checked: true, + value: true, + }, + }); + } + }, [changeMethods, db?.parameters?.ssh, db?.ssh_tunnel, isSSHTunnelEnabled]); + + return isSSHTunnelEnabled ? (
infoTooltip(theme)}> { - setUseSSHTunneling(changed); - if (!changed) { - setDB({ - type: ActionType.RemoveSSHTunnelConfig, - }); - } - }} + checked={isChecked} + onChange={handleOnChange} data-test="ssh-tunnel-switch" /> {t('SSH Tunnel')} @@ -55,4 +93,6 @@ const SSHTunnelSwitch = ({ />
) : null; +}; + export default SSHTunnelSwitch; diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx index 0f60857f06..7e8018b25f 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx @@ -16,6 +16,9 @@ * specific language governing permissions and limitations * under the License. */ + +// TODO: These tests should be made atomic in separate files + import React from 'react'; import fetchMock from 'fetch-mock'; import userEvent from '@testing-library/user-event'; @@ -1227,9 +1230,9 @@ describe('DatabaseModal', () => { const SSHTunnelServerPortInput = screen.getByTestId( 'ssh-tunnel-server_port-input', ); - expect(SSHTunnelServerPortInput).toHaveValue(''); + expect(SSHTunnelServerPortInput).toHaveValue(null); userEvent.type(SSHTunnelServerPortInput, '22'); - expect(SSHTunnelServerPortInput).toHaveValue('22'); + expect(SSHTunnelServerPortInput).toHaveValue(22); const SSHTunnelUsernameInput = screen.getByTestId( 'ssh-tunnel-username-input', ); @@ -1263,9 +1266,9 @@ describe('DatabaseModal', () => { const SSHTunnelServerPortInput = screen.getByTestId( 'ssh-tunnel-server_port-input', ); - expect(SSHTunnelServerPortInput).toHaveValue(''); + expect(SSHTunnelServerPortInput).toHaveValue(null); userEvent.type(SSHTunnelServerPortInput, '22'); - expect(SSHTunnelServerPortInput).toHaveValue('22'); + expect(SSHTunnelServerPortInput).toHaveValue(22); const SSHTunnelUsernameInput = screen.getByTestId( 'ssh-tunnel-username-input', ); diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx index 60ae032feb..47c9a8b658 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx @@ -20,8 +20,6 @@ import { t, styled, SupersetTheme, - FeatureFlag, - isFeatureEnabled, getExtensionsRegistry, } from '@superset-ui/core'; import React, { @@ -31,6 +29,7 @@ import React, { useState, useReducer, Reducer, + useCallback, } from 'react'; import { useHistory } from 'react-router-dom'; import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers'; @@ -65,6 +64,7 @@ import { CatalogObject, Engines, ExtraJson, + CustomTextType, } from '../types'; import ExtraOptions from './ExtraOptions'; import SqlAlchemyForm from './SqlAlchemyForm'; @@ -208,8 +208,8 @@ export type DBReducerActionType = | { type: | ActionType.Reset - | ActionType.AddTableCatalogSheet - | ActionType.RemoveSSHTunnelConfig; + | ActionType.RemoveSSHTunnelConfig + | ActionType.AddTableCatalogSheet; } | { type: ActionType.RemoveTableCatalogSheet; @@ -595,7 +595,9 @@ const DatabaseModal: FunctionComponent = ({ const SSHTunnelSwitchComponent = extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch; - const [useSSHTunneling, setUseSSHTunneling] = useState(false); + const [useSSHTunneling, setUseSSHTunneling] = useState( + undefined, + ); let dbConfigExtraExtension = extensionsRegistry.get( 'databaseconnection.extraOption', @@ -618,14 +620,6 @@ const DatabaseModal: FunctionComponent = ({ const dbImages = getDatabaseImages(); const connectionAlert = getConnectionAlert(); const isEditMode = !!databaseId; - const disableSSHTunnelingForEngine = ( - availableDbs?.databases?.find( - (DB: DatabaseObject) => - DB.backend === db?.engine || DB.engine === db?.engine, - ) as DatabaseObject - )?.engine_information?.disable_ssh_tunneling; - const isSSHTunneling = - isFeatureEnabled(FeatureFlag.SshTunneling) && !disableSSHTunnelingForEngine; const hasAlert = connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]); const useSqlAlchemyForm = @@ -659,7 +653,13 @@ const DatabaseModal: FunctionComponent = ({ extra: db?.extra, masked_encrypted_extra: db?.masked_encrypted_extra || '', server_cert: db?.server_cert || undefined, - ssh_tunnel: db?.ssh_tunnel || undefined, + ssh_tunnel: + !isEmpty(db?.ssh_tunnel) && useSSHTunneling + ? { + ...db.ssh_tunnel, + server_port: Number(db.ssh_tunnel!.server_port), + } + : undefined, }; setTestInProgress(true); testDatabaseConnection( @@ -687,10 +687,36 @@ const DatabaseModal: FunctionComponent = ({ return false; }; + const onChange = useCallback( + ( + type: DBReducerActionType['type'], + payload: CustomTextType | DBReducerPayloadType, + ) => { + setDB({ type, payload } as DBReducerActionType); + }, + [], + ); + + const handleClearValidationErrors = useCallback(() => { + setValidationErrors(null); + }, [setValidationErrors]); + + const handleParametersChange = useCallback( + ({ target }: { target: HTMLInputElement }) => { + onChange(ActionType.ParametersChange, { + type: target.type, + name: target.name, + checked: target.checked, + value: target.value, + }); + }, + [onChange], + ); + const onClose = () => { setDB({ type: ActionType.Reset }); setHasConnectedDb(false); - setValidationErrors(null); // reset validation errors on close + handleClearValidationErrors(); // reset validation errors on close clearError(); setEditNewDb(false); setFileList([]); @@ -705,7 +731,7 @@ const DatabaseModal: FunctionComponent = ({ setSSHTunnelPrivateKeys({}); setSSHTunnelPrivateKeyPasswords({}); setConfirmedOverwrite(false); - setUseSSHTunneling(false); + setUseSSHTunneling(undefined); onHide(); }; @@ -729,12 +755,11 @@ const DatabaseModal: FunctionComponent = ({ setImportingErrorMessage(msg); }); - const onChange = (type: any, payload: any) => { - setDB({ type, payload } as DBReducerActionType); - }; - const onSave = async () => { let dbConfigExtraExtensionOnSaveError; + + setLoading(true); + dbConfigExtraExtension ?.onSave(extraExtensionComponentState, db) .then(({ error }: { error: any }) => { @@ -743,6 +768,7 @@ const DatabaseModal: FunctionComponent = ({ addDangerToast(error); } }); + if (dbConfigExtraExtensionOnSaveError) { setLoading(false); return; @@ -762,17 +788,13 @@ const DatabaseModal: FunctionComponent = ({ }); } - // only do validation for non ssh tunnel connections - if (!dbToUpdate?.ssh_tunnel) { - // make sure that button spinner animates - setLoading(true); - const errors = await getValidation(dbToUpdate, true); - if ((validationErrors && !isEmpty(validationErrors)) || errors) { - setLoading(false); - return; - } - // end spinner animation + const errors = await getValidation(dbToUpdate, true); + if (!isEmpty(validationErrors) || errors?.length) { + addDangerToast( + t('Connection failed, please check your connection settings.'), + ); setLoading(false); + return; } const parameters_schema = isEditMode @@ -829,7 +851,12 @@ const DatabaseModal: FunctionComponent = ({ }); } - setLoading(true); + // strictly checking for false as an indication that the toggle got unchecked + if (useSSHTunneling === false) { + // remove ssh tunnel + dbToUpdate.ssh_tunnel = null; + } + if (db?.id) { const result = await updateResource( db.id as number, @@ -1282,10 +1309,10 @@ const DatabaseModal: FunctionComponent = ({ }, [sshPrivateKeyPasswordNeeded]); useEffect(() => { - if (db && isSSHTunneling) { - setUseSSHTunneling(!isEmpty(db?.ssh_tunnel)); + if (db?.parameters?.ssh !== undefined) { + setUseSSHTunneling(db.parameters.ssh); } - }, [db, isSSHTunneling]); + }, [db?.parameters?.ssh]); const onDbImport = async (info: UploadChangeParam) => { setImportingErrorMessage(''); @@ -1550,17 +1577,14 @@ const DatabaseModal: FunctionComponent = ({ const renderSSHTunnelForm = () => ( + onSSHTunnelParametersChange={({ target }) => { onChange(ActionType.ParametersSSHTunnelChange, { type: target.type, name: target.name, value: target.value, - }) - } + }); + handleClearValidationErrors(); + }} setSSHTunnelLoginMethod={(method: AuthType) => setDB({ type: ActionType.SetSSHTunnelLoginMethod, @@ -1623,14 +1647,7 @@ const DatabaseModal: FunctionComponent = ({ payload: { indexToDelete: idx }, }); }} - onParametersChange={({ target }: { target: HTMLInputElement }) => - onChange(ActionType.ParametersChange, { - type: target.type, - name: target.name, - checked: target.checked, - value: target.value, - }) - } + onParametersChange={handleParametersChange} onChange={({ target }: { target: HTMLInputElement }) => onChange(ActionType.TextChange, { name: target.name, @@ -1640,9 +1657,9 @@ const DatabaseModal: FunctionComponent = ({ getValidation={() => getValidation(db)} validationErrors={validationErrors} getPlaceholder={getPlaceholder} - clearValidationErrors={() => setValidationErrors(null)} + clearValidationErrors={handleClearValidationErrors} /> - {db?.parameters?.ssh && ( + {useSSHTunneling && ( {renderSSHTunnelForm()} )} @@ -1792,13 +1809,12 @@ const DatabaseModal: FunctionComponent = ({ testInProgress={testInProgress} > {useSSHTunneling && renderSSHTunnelForm()} diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts index 50e535f9b1..58d533c7be 100644 --- a/superset-frontend/src/features/databases/types.ts +++ b/superset-frontend/src/features/databases/types.ts @@ -1,3 +1,7 @@ +import { JsonObject } from '@superset-ui/core'; +import { InputProps } from 'antd/lib/input'; +import { ChangeEvent, EventHandler, FormEvent } from 'react'; + /** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -108,7 +112,7 @@ export type DatabaseObject = { }; // SSH Tunnel information - ssh_tunnel?: SSHTunnelObject; + ssh_tunnel?: SSHTunnelObject | null; }; export type DatabaseForm = { @@ -195,6 +199,10 @@ export type DatabaseForm = { }; preferred: boolean; sqlalchemy_uri_placeholder: string; + engine_information: { + supports_file_upload: boolean; + disable_ssh_tunneling: boolean; + }; }; // the values should align with the database @@ -231,3 +239,73 @@ export interface ExtraJson { }; version?: string; } + +export type CustomTextType = { + value?: string | boolean | number; + type?: string | null; + name?: string; + checked?: boolean; +}; + +type CustomHTMLInputElement = Omit, 'value' | 'type'> & + CustomTextType; + +type CustomHTMLTextAreaElement = Omit< + Partial, + 'value' | 'type' +> & + CustomTextType; + +export type CustomParametersChangeType = + | FormEvent + | { target: T }; + +export type CustomEventHandlerType = EventHandler< + ChangeEvent +>; + +export interface FieldPropTypes { + required: boolean; + hasTooltip?: boolean; + tooltipText?: (value: any) => string; + placeholder?: string; + onParametersChange: (event: CustomParametersChangeType) => void; + onParametersUploadFileChange: (value: any) => string; + changeMethods: { + onParametersChange: (event: CustomParametersChangeType) => void; + } & { + onChange: (value: any) => string; + } & { + onQueryChange: (value: any) => string; + } & { onParametersUploadFileChange: (value: any) => string } & { + onAddTableCatalog: () => void; + onRemoveTableCatalog: (idx: number) => void; + } & { + onExtraInputChange: (value: any) => void; + onSSHTunnelParametersChange: CustomEventHandlerType; + }; + validationErrors: JsonObject | null; + getValidation: () => void; + clearValidationErrors: () => void; + db?: DatabaseObject; + dbModel?: DatabaseForm; + field: string; + isEditMode?: boolean; + sslForced?: boolean; + defaultDBName?: string; + editNewDb?: boolean; +} + +type ChangeMethodsType = FieldPropTypes['changeMethods']; + +// changeMethods compatibility with dynamic forms +type SwitchPropsChangeMethodsType = { + onParametersChange: ChangeMethodsType['onParametersChange']; +}; + +export type SwitchProps = { + dbModel: DatabaseForm; + db: DatabaseObject; + changeMethods: SwitchPropsChangeMethodsType; + clearValidationErrors: () => void; +}; diff --git a/superset-frontend/src/views/CRUD/hooks.ts b/superset-frontend/src/views/CRUD/hooks.ts index 85f7c60252..8f31f2fcdd 100644 --- a/superset-frontend/src/views/CRUD/hooks.ts +++ b/superset-frontend/src/views/CRUD/hooks.ts @@ -35,7 +35,8 @@ import Chart, { Slice } from 'src/types/Chart'; import copyTextToClipboard from 'src/utils/copy'; import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import SupersetText from 'src/utils/textUtils'; -import { FavoriteStatus, ImportResourceName, DatabaseObject } from './types'; +import { DatabaseObject } from 'src/features/databases/types'; +import { FavoriteStatus, ImportResourceName } from './types'; interface ListViewResourceState { loading: boolean; @@ -691,7 +692,7 @@ export const getDatabaseDocumentationLinks = () => SupersetText.DB_CONNECTION_DOC_LINKS; export const testDatabaseConnection = ( - connection: DatabaseObject, + connection: Partial, handleErrorMsg: (errorMsg: string) => void, addSuccessToast: (arg0: string) => void, ) => { @@ -745,7 +746,7 @@ export function useDatabaseValidation() { const getValidation = useCallback( (database: Partial | null, onCreate = false) => { if (database?.parameters?.ssh) { - // when ssh tunnel is enabled we don't want to render any validation errors + // TODO: /validate_parameters/ and related utils should support ssh tunnel setValidationErrors(null); return []; } diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index cde9dd8e88..9efb39b75a 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -19,6 +19,7 @@ from typing import Any, Optional from flask import current_app from flask_appbuilder.models.sqla import Model +from flask_babel import gettext as _ from marshmallow import ValidationError from superset import is_feature_enabled @@ -33,6 +34,7 @@ from superset.commands.database.exceptions import ( from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, SSHTunnelingNotEnabledError, SSHTunnelInvalidError, ) @@ -57,7 +59,11 @@ class CreateDatabaseCommand(BaseCommand): try: # Test connection before starting create transaction TestConnectionDatabaseCommand(self._properties).run() - except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex: + except ( + SupersetErrorsException, + SSHTunnelingNotEnabledError, + SSHTunnelDatabasePortError, + ) as ex: event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], @@ -103,6 +109,7 @@ class CreateDatabaseCommand(BaseCommand): SSHTunnelInvalidError, SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError, + SSHTunnelDatabasePortError, ) as ex: db.session.rollback() event_logger.log_with_context( @@ -140,6 +147,7 @@ class CreateDatabaseCommand(BaseCommand): # Check database_name uniqueness if not DatabaseDAO.validate_uniqueness(database_name): exceptions.append(DatabaseExistsValidationError()) + if exceptions: exception = DatabaseInvalidError() exception.extend(exceptions) diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index cbfee3ce2a..287accc5aa 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -23,11 +23,13 @@ from marshmallow import ValidationError from superset.commands.base import BaseCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, SSHTunnelInvalidError, SSHTunnelRequiredFieldValidationError, ) from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOCreateFailedError +from superset.databases.utils import make_url_safe from superset.extensions import event_logger from superset.models.core import Database @@ -35,9 +37,12 @@ logger = logging.getLogger(__name__) class CreateSSHTunnelCommand(BaseCommand): + _database: Database + def __init__(self, database: Database, data: dict[str, Any]): self._properties = data.copy() self._properties["database"] = database + self._database = database def run(self) -> Model: try: @@ -57,16 +62,22 @@ class CreateSSHTunnelCommand(BaseCommand): server_address: Optional[str] = self._properties.get("server_address") server_port: Optional[int] = self._properties.get("server_port") username: Optional[str] = self._properties.get("username") + password: Optional[str] = self._properties.get("password") private_key: Optional[str] = self._properties.get("private_key") private_key_password: Optional[str] = self._properties.get( "private_key_password" ) + url = make_url_safe(self._database.sqlalchemy_uri) + if not url.port: + raise SSHTunnelDatabasePortError() if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) if not server_port: exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) if not username: exceptions.append(SSHTunnelRequiredFieldValidationError("username")) + if not private_key and not password: + exceptions.append(SSHTunnelRequiredFieldValidationError("password")) if private_key_password and private_key is None: exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) if exceptions: diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py index 0e3f91cae6..a0def8c087 100644 --- a/superset/commands/database/ssh_tunnel/exceptions.py +++ b/superset/commands/database/ssh_tunnel/exceptions.py @@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError): message = _("SSH Tunnel parameters are invalid.") +class SSHTunnelDatabasePortError(CommandInvalidError): + message = _("A database port is required when connecting via SSH Tunnel.") + + class SSHTunnelUpdateFailedError(UpdateFailedError): message = _("SSH Tunnel could not be updated.") diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py index ae7ee78afe..d0dd14a5b2 100644 --- a/superset/commands/database/ssh_tunnel/update.py +++ b/superset/commands/database/ssh_tunnel/update.py @@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model from superset.commands.base import BaseCommand from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, SSHTunnelInvalidError, SSHTunnelNotFoundError, SSHTunnelRequiredFieldValidationError, @@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOUpdateFailedError from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.databases.utils import make_url_safe logger = logging.getLogger(__name__) @@ -39,20 +41,33 @@ class UpdateSSHTunnelCommand(BaseCommand): self._model_id = model_id self._model: Optional[SSHTunnel] = None - def run(self) -> Model: + def run(self) -> Optional[Model]: self.validate() try: - if self._model is not None: # So we dont get incompatible types error - tunnel = SSHTunnelDAO.update(self._model, self._properties) + if self._model is None: + return None + + # unset password if private key is provided + if self._properties.get("private_key"): + self._properties["password"] = None + + # unset private key and password if password is provided + if self._properties.get("password"): + self._properties["private_key"] = None + self._properties["private_key_password"] = None + + tunnel = SSHTunnelDAO.update(self._model, self._properties) + return tunnel except DAOUpdateFailedError as ex: raise SSHTunnelUpdateFailedError() from ex - return tunnel def validate(self) -> None: # Validate/populate model exists self._model = SSHTunnelDAO.find_by_id(self._model_id) if not self._model: raise SSHTunnelNotFoundError() + + url = make_url_safe(self._model.database.sqlalchemy_uri) private_key: Optional[str] = self._properties.get("private_key") private_key_password: Optional[str] = self._properties.get( "private_key_password" @@ -61,3 +76,5 @@ class UpdateSSHTunnelCommand(BaseCommand): raise SSHTunnelInvalidError( exceptions=[SSHTunnelRequiredFieldValidationError("private_key")] ) + if not url.port: + raise SSHTunnelDatabasePortError() diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 0ffdf3ddd9..431918c6bc 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -32,7 +32,10 @@ from superset.commands.database.exceptions import ( DatabaseTestConnectionDriverError, DatabaseTestConnectionUnexpectedError, ) -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError +from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, + SSHTunnelingNotEnabledError, +) from superset.daos.database import DatabaseDAO, SSHTunnelDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -61,20 +64,22 @@ def get_log_connection_action( class TestConnectionDatabaseCommand(BaseCommand): + _model: Optional[Database] = None + _context: dict[str, Any] + _uri: str + def __init__(self, data: dict[str, Any]): self._properties = data.copy() - self._model: Optional[Database] = None - def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches - self.validate() - ex_str = "" + if (database_name := self._properties.get("database_name")) is not None: + self._model = DatabaseDAO.get_database_by_name(database_name) + uri = self._properties.get("sqlalchemy_uri", "") if self._model and uri == self._model.safe_sqlalchemy_uri(): uri = self._model.sqlalchemy_uri_decrypted - ssh_tunnel = self._properties.get("ssh_tunnel") - # context for error messages url = make_url_safe(uri) + context = { "hostname": url.host, "password": url.password, @@ -83,6 +88,14 @@ class TestConnectionDatabaseCommand(BaseCommand): "database": url.database, } + self._context = context + self._uri = uri + + def run(self) -> None: # pylint: disable=too-many-statements + self.validate() + ex_str = "" + ssh_tunnel = self._properties.get("ssh_tunnel") + serialized_encrypted_extra = self._properties.get( "masked_encrypted_extra", "{}", @@ -103,15 +116,12 @@ class TestConnectionDatabaseCommand(BaseCommand): encrypted_extra=serialized_encrypted_extra, ) - database.set_sqlalchemy_uri(uri) + database.set_sqlalchemy_uri(self._uri) database.db_engine_spec.mutate_db_for_connection_test(database) # Generate tunnel if present in the properties if ssh_tunnel: - if not is_feature_enabled("SSH_TUNNELING"): - raise SSHTunnelingNotEnabledError() - # If there's an existing tunnel for that DB we need to use the stored - # password, private_key and private_key_password instead + # unmask password while allowing for updated values if ssh_tunnel_id := ssh_tunnel.pop("id", None): if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id): ssh_tunnel = unmask_password_info( @@ -186,7 +196,7 @@ class TestConnectionDatabaseCommand(BaseCommand): engine=database.db_engine_spec.__name__, ) # check for custom errors (wrong username, wrong password, etc) - errors = database.db_engine_spec.extract_errors(ex, context) + errors = database.db_engine_spec.extract_errors(ex, self._context) raise SupersetErrorsException(errors) from ex except SupersetSecurityException as ex: event_logger.log_with_context( @@ -221,9 +231,12 @@ class TestConnectionDatabaseCommand(BaseCommand): ), engine=database.db_engine_spec.__name__, ) - errors = database.db_engine_spec.extract_errors(ex, context) + errors = database.db_engine_spec.extract_errors(ex, self._context) raise DatabaseTestConnectionUnexpectedError(errors) from ex def validate(self) -> None: - if (database_name := self._properties.get("database_name")) is not None: - self._model = DatabaseDAO.get_database_by_name(database_name) + if self._properties.get("ssh_tunnel"): + if not is_feature_enabled("SSH_TUNNELING"): + raise SSHTunnelingNotEnabledError() + if not self._context.get("port"): + raise SSHTunnelDatabasePortError() diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index edc0ba1b98..5575d674a8 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -18,6 +18,7 @@ import logging from typing import Any, Optional from flask_appbuilder.models.sqla import Model +from flask_babel import gettext as _ from marshmallow import ValidationError from superset import is_feature_enabled @@ -30,8 +31,11 @@ from superset.commands.database.exceptions import ( DatabaseUpdateFailedError, ) from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand +from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, + SSHTunnelDeleteFailedError, SSHTunnelingNotEnabledError, SSHTunnelInvalidError, SSHTunnelUpdateFailedError, @@ -47,15 +51,21 @@ logger = logging.getLogger(__name__) class UpdateDatabaseCommand(BaseCommand): + _model: Optional[Database] + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[Database] = None - def run(self) -> Model: - self.validate() + def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches + self._model = DatabaseDAO.find_by_id(self._model_id) + if not self._model: raise DatabaseNotFoundError() + + self.validate() + old_database_name = self._model.database_name # unmask ``encrypted_extra`` @@ -70,36 +80,59 @@ class UpdateDatabaseCommand(BaseCommand): database = DatabaseDAO.update(self._model, self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) + + if "ssh_tunnel" in self._properties: if not is_feature_enabled("SSH_TUNNELING"): db.session.rollback() raise SSHTunnelingNotEnabledError() - existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id) - if existing_ssh_tunnel_model is None: - # We couldn't found an existing tunnel so we need to create one + + if self._properties.get("ssh_tunnel") is None and ssh_tunnel: + # We need to remove the existing tunnel try: - CreateSSHTunnelCommand(database, ssh_tunnel_properties).run() - except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: - # So we can show the original message - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - else: - # We found an existing tunnel so we need to update it - try: - UpdateSSHTunnelCommand( - existing_ssh_tunnel_model.id, ssh_tunnel_properties - ).run() - except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex: - # So we can show the original message + DeleteSSHTunnelCommand(ssh_tunnel.id).run() + ssh_tunnel = None + except SSHTunnelDeleteFailedError as ex: raise ex except Exception as ex: raise DatabaseUpdateFailedError() from ex + if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + if ssh_tunnel is None: + # We couldn't found an existing tunnel so we need to create one + try: + ssh_tunnel = CreateSSHTunnelCommand( + database, ssh_tunnel_properties + ).run() + except ( + SSHTunnelInvalidError, + SSHTunnelCreateFailedError, + SSHTunnelDatabasePortError, + ) as ex: + # So we can show the original message + raise ex + except Exception as ex: + raise DatabaseUpdateFailedError() from ex + else: + # We found an existing tunnel so we need to update it + try: + ssh_tunnel_id = ssh_tunnel.id + ssh_tunnel = UpdateSSHTunnelCommand( + ssh_tunnel_id, ssh_tunnel_properties + ).run() + except ( + SSHTunnelInvalidError, + SSHTunnelUpdateFailedError, + SSHTunnelDatabasePortError, + ) as ex: + # So we can show the original message + raise ex + except Exception as ex: + raise DatabaseUpdateFailedError() from ex + # adding a new database we always want to force refresh schema list # TODO Improve this simplistic implementation for catching DB conn fails try: - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) except Exception as ex: db.session.rollback() @@ -167,10 +200,6 @@ class UpdateDatabaseCommand(BaseCommand): def validate(self) -> None: exceptions: list[ValidationError] = [] - # Validate/populate model exists - self._model = DatabaseDAO.find_by_id(self._model_id) - if not self._model: - raise DatabaseNotFoundError() database_name: Optional[str] = self._properties.get("database_name") if database_name: # Check database_name uniqueness diff --git a/superset/databases/api.py b/superset/databases/api.py index ceea8230c1..1e44a52106 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand from superset.commands.database.importers.dispatcher import ImportDatabasesCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, SSHTunnelDeleteFailedError, SSHTunnelingNotEnabledError, ) @@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): exc_info=True, ) return self.response_422(message=str(ex)) - except SSHTunnelingNotEnabledError as ex: + except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex: return self.response_400(message=str(ex)) except SupersetException as ex: return self.response(ex.status, message=ex.message) @@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): exc_info=True, ) return self.response_422(message=str(ex)) - except SSHTunnelingNotEnabledError as ex: + except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex: return self.response_400(message=str(ex)) @expose("/", methods=("DELETE",)) @@ -918,7 +919,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): try: TestConnectionDatabaseCommand(item).run() return self.response(200, message="OK") - except SSHTunnelingNotEnabledError as ex: + except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex: return self.response_400(message=str(ex)) @expose("//related_objects/", methods=("GET",)) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index ebabc16e87..0f9dc03723 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError from sqlalchemy.sql import func from superset import db, security_manager +from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError from superset.connectors.sqla.models import SqlaTable from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch("superset.commands.database.create.is_feature_enabled") + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_create_database_with_missing_port_raises_error( + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_get_all_schema_names, + ): + """ + Database API: Test that missing port raises SSHTunnelDatabaseError + """ + mock_create_is_feature_enabled.return_value = True + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db" + + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": modified_sqlalchemy_uri, + "ssh_tunnel": ssh_tunnel_properties, + } + + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": modified_sqlalchemy_uri, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data_with_ssh_tunnel) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + self.assertEqual( + response.get("message"), + "A database port is required when connecting via SSH Tunnel.", + ) + @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch("superset.commands.database.create.is_feature_enabled") + @mock.patch("superset.commands.database.update.is_feature_enabled") + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_update_database_with_missing_port_raises_error( + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_update_is_feature_enabled, + mock_get_all_schema_names, + ): + """ + Database API: Test that missing port raises SSHTunnelDatabaseError + """ + mock_create_is_feature_enabled.return_value = True + mock_update_is_feature_enabled.return_value = True + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db" + + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": modified_sqlalchemy_uri, + "ssh_tunnel": ssh_tunnel_properties, + } + + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response_create = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + uri = "api/v1/database/{}".format(response_create.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + self.assertEqual( + response.get("message"), + "A database port is required when connecting via SSH Tunnel.", + ) + + # Cleanup + model = db.session.query(Database).get(response_create.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch("superset.commands.database.create.is_feature_enabled") + @mock.patch("superset.commands.database.update.is_feature_enabled") + @mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled") + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_delete_ssh_tunnel( + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_update_is_feature_enabled, + mock_delete_is_feature_enabled, + mock_get_all_schema_names, + ): + """ + Database API: Test deleting a SSH tunnel via Database update + """ + mock_create_is_feature_enabled.return_value = True + mock_update_is_feature_enabled.return_value = True + mock_delete_is_feature_enabled.return_value = True + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + + database_data_with_ssh_tunnel_null = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": None, + } + + rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one_or_none() + ) + + assert model_ssh_tunnel is None + + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 4b05cce637..c80b52931d 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -19,7 +19,10 @@ import pytest from sqlalchemy.orm.session import Session -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError +from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, + SSHTunnelInvalidError, +) def test_create_ssh_tunnel_command() -> None: @@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database - database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database( + id=1, + database_name="my_database", + sqlalchemy_uri="postgresql://u:p@localhost:5432/db", + ) properties = { "database_id": database.id, @@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database - database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database( + id=1, + database_name="my_database", + sqlalchemy_uri="postgresql://u:p@localhost:5432/db", + ) # If we are trying to create a tunnel with a private_key_password # then a private_key is mandatory @@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: with pytest.raises(SSHTunnelInvalidError) as excinfo: command.run() assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") + + +def test_create_ssh_tunnel_command_no_port() -> None: + from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + database = Database( + id=1, + database_name="my_database", + sqlalchemy_uri="postgresql://u:p@localhost/db", + ) + + properties = { + "database": database, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + command = CreateSSHTunnelCommand(database, properties) + + with pytest.raises(SSHTunnelDatabasePortError) as excinfo: + command.run() + assert str(excinfo.value) == ( + "A database port is required when connecting via SSH Tunnel." + ) diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index 54e54d05da..66684eb8de 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -20,11 +20,14 @@ from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError +from superset.commands.database.ssh_tunnel.exceptions import ( + SSHTunnelDatabasePortError, + SSHTunnelInvalidError, +) @pytest.fixture -def session_with_data(session: Session) -> Iterator[Session]: +def session_with_data(request, session: Session) -> Iterator[Session]: from superset.connectors.sqla.models import SqlaTable from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database @@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]: engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db") + database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri) sqla_table = SqlaTable( table_name="my_sqla_table", columns=[], @@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None: with pytest.raises(SSHTunnelInvalidError) as excinfo: command.run() assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") + + +@pytest.mark.parametrize( + "session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True +) +def test_update_shh_tunnel_no_port(session_with_data: Session) -> None: + from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand + from superset.daos.database import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + assert "Test" == result.server_address + + update_payload = {"server_address": "Test update"} + command = UpdateSSHTunnelCommand(1, update_payload) + + with pytest.raises(SSHTunnelDatabasePortError) as excinfo: + command.run() + assert str(excinfo.value) == ( + "A database port is required when connecting via SSH Tunnel." + )