diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2e7d79e405..4e5a860fd3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -623,6 +623,11 @@ cd superset-frontend npm run test ``` +To run a single test file: +```bash +npm run test -- path/to/file.js +``` + ### Integration Testing We use [Cypress](https://www.cypress.io/) for integration tests. Tests can be run by `tox -e cypress`. To open Cypress and explore tests first setup and run test server: diff --git a/UPDATING.md b/UPDATING.md index 79777b79f2..9e92d36c46 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -23,7 +23,7 @@ This file documents any backwards-incompatible changes in Superset and assists people when migrating to a new version. ## Next - +- [11499](https://github.com/apache/incubator-superset/pull/11499): Breaking change: `STORE_CACHE_KEYS_IN_METADATA_DB` config flag added (default=`False`) to write `CacheKey` records to the metadata DB. `CacheKey` recording was enabled by default previously. - [11920](https://github.com/apache/incubator-superset/pull/11920): Undos the DB migration from [11714](https://github.com/apache/incubator-superset/pull/11714) to prevent adding new columns to the logs table. Deploying a sha between these two PRs may result in locking your DB. - [11704](https://github.com/apache/incubator-superset/pull/11704) Breaking change: Jinja templating for SQL queries has been updated, removing default modules such as `datetime` and `random` and enforcing static template values. To restore or extend functionality, use `JINJA_CONTEXT_ADDONS` and `CUSTOM_TEMPLATE_PROCESSORS` in `superset_config.py`. - [11714](https://github.com/apache/incubator-superset/pull/11714): Logs diff --git a/setup.cfg b/setup.cfg index 28c8e77df6..38856a3504 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,redis,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/superset-frontend/spec/javascripts/middleware/asyncEvent_spec.ts b/superset-frontend/spec/javascripts/middleware/asyncEvent_spec.ts new file mode 100644 index 0000000000..e42ac9152f --- /dev/null +++ b/superset-frontend/spec/javascripts/middleware/asyncEvent_spec.ts @@ -0,0 +1,265 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import fetchMock from 'fetch-mock'; +import sinon from 'sinon'; +import * as featureFlags from 'src/featureFlags'; +import initAsyncEvents from 'src/middleware/asyncEvent'; + +jest.useFakeTimers(); + +describe('asyncEvent middleware', () => { + const next = sinon.spy(); + const state = { + charts: { + 123: { + id: 123, + status: 'loading', + asyncJobId: 'foo123', + }, + 345: { + id: 345, + status: 'loading', + asyncJobId: 'foo345', + }, + }, + }; + const events = [ + { + status: 'done', + result_url: '/api/v1/chart/data/cache-key-1', + job_id: 'foo123', + channel_id: '999', + errors: [], + }, + { + status: 'done', + result_url: '/api/v1/chart/data/cache-key-2', + job_id: 'foo345', + channel_id: '999', + errors: [], + }, + ]; + const mockStore = { + getState: () => state, + dispatch: sinon.stub(), + }; + const action = { + type: 'GENERIC_ACTION', + }; + const EVENTS_ENDPOINT = 'glob:*/api/v1/async_event/*'; + const CACHED_DATA_ENDPOINT = 'glob:*/api/v1/chart/data/*'; + const config = { + GLOBAL_ASYNC_QUERIES_TRANSPORT: 'polling', + GLOBAL_ASYNC_QUERIES_POLLING_DELAY: 500, + }; + let featureEnabledStub: any; + + function setup() { + const getPendingComponents = sinon.stub(); + const successAction = sinon.spy(); + const errorAction = sinon.spy(); + const testCallback = sinon.stub(); + const testCallbackPromise = sinon.stub(); + testCallbackPromise.returns( + new Promise(resolve => { + testCallback.callsFake(resolve); + }), + ); + + return { + getPendingComponents, + successAction, + errorAction, + testCallback, + testCallbackPromise, + }; + } + + beforeEach(() => { + fetchMock.get(EVENTS_ENDPOINT, { + status: 200, + body: { result: [] }, + }); + fetchMock.get(CACHED_DATA_ENDPOINT, { + status: 200, + body: { result: { some: 'data' } }, + }); + featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled'); + featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(true); + }); + afterEach(() => { + fetchMock.reset(); + next.resetHistory(); + featureEnabledStub.restore(); + }); + afterAll(fetchMock.reset); + + it('should initialize and call next', () => { + const { getPendingComponents, successAction, errorAction } = setup(); + getPendingComponents.returns([]); + const asyncEventMiddleware = initAsyncEvents({ + config, + getPendingComponents, + successAction, + errorAction, + }); + asyncEventMiddleware(mockStore)(next)(action); + expect(next.callCount).toBe(1); + }); + + it('should fetch events when there are pending components', () => { + const { + getPendingComponents, + successAction, + errorAction, + testCallback, + testCallbackPromise, + } = setup(); + getPendingComponents.returns(Object.values(state.charts)); + const asyncEventMiddleware = initAsyncEvents({ + config, + getPendingComponents, + successAction, + errorAction, + processEventsCallback: testCallback, + }); + + 
asyncEventMiddleware(mockStore)(next)(action); + + return testCallbackPromise().then(() => { + expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1); + }); + }); + + it('should fetch cached when there are successful events', () => { + const { + getPendingComponents, + successAction, + errorAction, + testCallback, + testCallbackPromise, + } = setup(); + fetchMock.reset(); + fetchMock.get(EVENTS_ENDPOINT, { + status: 200, + body: { result: events }, + }); + fetchMock.get(CACHED_DATA_ENDPOINT, { + status: 200, + body: { result: { some: 'data' } }, + }); + getPendingComponents.returns(Object.values(state.charts)); + const asyncEventMiddleware = initAsyncEvents({ + config, + getPendingComponents, + successAction, + errorAction, + processEventsCallback: testCallback, + }); + + asyncEventMiddleware(mockStore)(next)(action); + + return testCallbackPromise().then(() => { + expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1); + expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2); + expect(successAction.callCount).toBe(2); + }); + }); + + it('should call errorAction for cache fetch error responses', () => { + const { + getPendingComponents, + successAction, + errorAction, + testCallback, + testCallbackPromise, + } = setup(); + fetchMock.reset(); + fetchMock.get(EVENTS_ENDPOINT, { + status: 200, + body: { result: events }, + }); + fetchMock.get(CACHED_DATA_ENDPOINT, { + status: 400, + body: { errors: ['error'] }, + }); + getPendingComponents.returns(Object.values(state.charts)); + const asyncEventMiddleware = initAsyncEvents({ + config, + getPendingComponents, + successAction, + errorAction, + processEventsCallback: testCallback, + }); + + asyncEventMiddleware(mockStore)(next)(action); + + return testCallbackPromise().then(() => { + expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1); + expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2); + expect(errorAction.callCount).toBe(2); + }); + }); + + it('should handle event fetching error responses', () => { + const { + getPendingComponents, + successAction, + errorAction, + testCallback, + testCallbackPromise, + } = setup(); + fetchMock.reset(); + fetchMock.get(EVENTS_ENDPOINT, { + status: 400, + body: { message: 'error' }, + }); + getPendingComponents.returns(Object.values(state.charts)); + const asyncEventMiddleware = initAsyncEvents({ + config, + getPendingComponents, + successAction, + errorAction, + processEventsCallback: testCallback, + }); + + asyncEventMiddleware(mockStore)(next)(action); + + return testCallbackPromise().then(() => { + expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1); + }); + }); + + it('should not fetch events when async queries are disabled', () => { + featureEnabledStub.restore(); + featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled'); + featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(false); + const { getPendingComponents, successAction, errorAction } = setup(); + getPendingComponents.returns(Object.values(state.charts)); + const asyncEventMiddleware = initAsyncEvents({ + config, + getPendingComponents, + successAction, + errorAction, + }); + + asyncEventMiddleware(mockStore)(next)(action); + expect(getPendingComponents.called).toBe(false); + }); +}); diff --git a/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts b/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts index 8519b71206..8e89fec284 100644 --- a/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts +++ 
b/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts @@ -17,7 +17,7 @@ * under the License. */ import { ErrorTypeEnum } from 'src/components/ErrorMessage/types'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; describe('getClientErrorObject()', () => { it('Returns a Promise', () => { diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 7c1bce3e1f..ebfca9116f 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -30,7 +30,7 @@ import { addSuccessToast as addSuccessToastAction, addWarningToast as addWarningToastAction, } from '../../messageToasts/actions/index'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; import COMMON_ERR_MESSAGES from '../../utils/errorMessages'; export const RESET_STATE = 'RESET_STATE'; diff --git a/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx b/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx index dc9fb4267f..c094d08775 100644 --- a/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx +++ b/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx @@ -25,7 +25,7 @@ import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags'; import Button from 'src/components/Button'; import CopyToClipboard from '../../components/CopyToClipboard'; import { storeQuery } from '../../utils/common'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; import withToasts from '../../messageToasts/enhancers/withToasts'; const propTypes = { diff --git a/superset-frontend/src/chart/chartAction.js b/superset-frontend/src/chart/chartAction.js index f6329a80a5..0198b33a3b 100644 --- a/superset-frontend/src/chart/chartAction.js +++ b/superset-frontend/src/chart/chartAction.js @@ -38,7 +38,7 @@ import { import { addDangerToast } from '../messageToasts/actions'; import { logEvent } from '../logger/actions'; import { Logger, LOG_ACTIONS_LOAD_CHART } from '../logger/LogUtils'; -import getClientErrorObject from '../utils/getClientErrorObject'; +import { getClientErrorObject } from '../utils/getClientErrorObject'; import { allowCrossDomain as domainShardingEnabled } from '../utils/hostNamesConfig'; export const CHART_UPDATE_STARTED = 'CHART_UPDATE_STARTED'; @@ -66,6 +66,11 @@ export function chartUpdateFailed(queryResponse, key) { return { type: CHART_UPDATE_FAILED, queryResponse, key }; } +export const CHART_UPDATE_QUEUED = 'CHART_UPDATE_QUEUED'; +export function chartUpdateQueued(asyncJobMeta, key) { + return { type: CHART_UPDATE_QUEUED, asyncJobMeta, key }; +} + export const CHART_RENDERING_FAILED = 'CHART_RENDERING_FAILED'; export function chartRenderingFailed(error, key, stackTrace) { return { type: CHART_RENDERING_FAILED, error, key, stackTrace }; @@ -356,6 +361,12 @@ export function exploreJSON( const chartDataRequestCaught = chartDataRequest .then(response => { + if (isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES)) { + // deal with getChartDataRequest transforming the response data + const result = 'result' in response ? 
response.result[0] : response; + return dispatch(chartUpdateQueued(result, key)); + } + // new API returns an object with an array of restults // problem: response holds a list of results, when before we were just getting one result. // How to make the entire app compatible with multiple results? diff --git a/superset-frontend/src/chart/chartReducer.js b/superset-frontend/src/chart/chartReducer.js index b3e72124f9..fc28e99d2c 100644 --- a/superset-frontend/src/chart/chartReducer.js +++ b/superset-frontend/src/chart/chartReducer.js @@ -71,6 +71,14 @@ export default function chartReducer(charts = {}, action) { chartUpdateEndTime: now(), }; }, + [actions.CHART_UPDATE_QUEUED](state) { + return { + ...state, + asyncJobId: action.asyncJobMeta.job_id, + chartStatus: 'loading', + chartUpdateEndTime: now(), + }; + }, [actions.CHART_RENDERING_SUCCEEDED](state) { return { ...state, chartStatus: 'rendered', chartUpdateEndTime: now() }; }, diff --git a/superset-frontend/src/components/AsyncSelect.jsx b/superset-frontend/src/components/AsyncSelect.jsx index fc9c5eeb59..93bbb34087 100644 --- a/superset-frontend/src/components/AsyncSelect.jsx +++ b/superset-frontend/src/components/AsyncSelect.jsx @@ -21,7 +21,7 @@ import PropTypes from 'prop-types'; // TODO: refactor this with `import { AsyncSelect } from src/components/Select` import { Select } from 'src/components/Select'; import { t, SupersetClient } from '@superset-ui/core'; -import getClientErrorObject from '../utils/getClientErrorObject'; +import { getClientErrorObject } from '../utils/getClientErrorObject'; const propTypes = { dataEndpoint: PropTypes.string.isRequired, diff --git a/superset-frontend/src/dashboard/actions/dashboardState.js b/superset-frontend/src/dashboard/actions/dashboardState.js index a499fa9a8e..fc31a8501a 100644 --- a/superset-frontend/src/dashboard/actions/dashboardState.js +++ b/superset-frontend/src/dashboard/actions/dashboardState.js @@ -29,7 +29,7 @@ import { updateDirectPathToFilter, } from './dashboardFilters'; import { applyDefaultFormData } from '../../explore/store'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; import { SAVE_TYPE_OVERWRITE } from '../util/constants'; import { addSuccessToast, diff --git a/superset-frontend/src/dashboard/actions/datasources.js b/superset-frontend/src/dashboard/actions/datasources.js index 40cba8559a..4277edc661 100644 --- a/superset-frontend/src/dashboard/actions/datasources.js +++ b/superset-frontend/src/dashboard/actions/datasources.js @@ -17,7 +17,7 @@ * under the License. 
*/ import { SupersetClient } from '@superset-ui/core'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; export const SET_DATASOURCE = 'SET_DATASOURCE'; export function setDatasource(datasource, key) { diff --git a/superset-frontend/src/dashboard/actions/sliceEntities.js b/superset-frontend/src/dashboard/actions/sliceEntities.js index 01512ddb14..69472d9155 100644 --- a/superset-frontend/src/dashboard/actions/sliceEntities.js +++ b/superset-frontend/src/dashboard/actions/sliceEntities.js @@ -22,7 +22,7 @@ import rison from 'rison'; import { addDangerToast } from 'src/messageToasts/actions'; import { getDatasourceParameter } from 'src/modules/utils'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; export const SET_ALL_SLICES = 'SET_ALL_SLICES'; export function setAllSlices(slices) { diff --git a/superset-frontend/src/dashboard/components/PropertiesModal.jsx b/superset-frontend/src/dashboard/components/PropertiesModal.jsx index 019118591d..dd17dcb7d5 100644 --- a/superset-frontend/src/dashboard/components/PropertiesModal.jsx +++ b/superset-frontend/src/dashboard/components/PropertiesModal.jsx @@ -35,7 +35,7 @@ import FormLabel from 'src/components/FormLabel'; import { JsonEditor } from 'src/components/AsyncAceEditor'; import ColorSchemeControlWrapper from 'src/dashboard/components/ColorSchemeControlWrapper'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; import withToasts from '../../messageToasts/enhancers/withToasts'; import '../stylesheets/buttons.less'; diff --git a/superset-frontend/src/dashboard/index.jsx b/superset-frontend/src/dashboard/index.jsx index 3937a357df..9fe82346c3 100644 --- a/superset-frontend/src/dashboard/index.jsx +++ b/superset-frontend/src/dashboard/index.jsx @@ -24,7 +24,9 @@ import { initFeatureFlags } from 'src/featureFlags'; import { initEnhancer } from '../reduxUtils'; import getInitialState from './reducers/getInitialState'; import rootReducer from './reducers/index'; +import initAsyncEvents from '../middleware/asyncEvent'; import logger from '../middleware/loggerMiddleware'; +import * as actions from '../chart/chartAction'; import App from './App'; @@ -33,10 +35,23 @@ const bootstrapData = JSON.parse(appContainer.getAttribute('data-bootstrap')); initFeatureFlags(bootstrapData.common.feature_flags); const initState = getInitialState(bootstrapData); +const asyncEventMiddleware = initAsyncEvents({ + config: bootstrapData.common.conf, + getPendingComponents: ({ charts }) => + Object.values(charts).filter(c => c.chartStatus === 'loading'), + successAction: (componentId, componentData) => + actions.chartUpdateSucceeded(componentData, componentId), + errorAction: (componentId, response) => + actions.chartUpdateFailed(response, componentId), +}); + const store = createStore( rootReducer, initState, - compose(applyMiddleware(thunk, logger), initEnhancer(false)), + compose( + applyMiddleware(thunk, logger, asyncEventMiddleware), + initEnhancer(false), + ), ); ReactDOM.render(, document.getElementById('app')); diff --git a/superset-frontend/src/datasource/ChangeDatasourceModal.tsx b/superset-frontend/src/datasource/ChangeDatasourceModal.tsx index 4f1597be90..81b9b8cb8c 100644 --- a/superset-frontend/src/datasource/ChangeDatasourceModal.tsx +++ 
b/superset-frontend/src/datasource/ChangeDatasourceModal.tsx @@ -27,7 +27,7 @@ import { Alert, FormControl, FormControlProps } from 'react-bootstrap'; import { SupersetClient, t } from '@superset-ui/core'; import TableView from 'src/components/TableView'; import Modal from 'src/common/components/Modal'; -import getClientErrorObject from '../utils/getClientErrorObject'; +import { getClientErrorObject } from '../utils/getClientErrorObject'; import Loading from '../components/Loading'; import withToasts from '../messageToasts/enhancers/withToasts'; diff --git a/superset-frontend/src/datasource/DatasourceEditor.jsx b/superset-frontend/src/datasource/DatasourceEditor.jsx index aebeaeb808..dc2d44444e 100644 --- a/superset-frontend/src/datasource/DatasourceEditor.jsx +++ b/superset-frontend/src/datasource/DatasourceEditor.jsx @@ -33,7 +33,7 @@ import Loading from 'src/components/Loading'; import TableSelector from 'src/components/TableSelector'; import EditableTitle from 'src/components/EditableTitle'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import CheckboxControl from 'src/explore/components/controls/CheckboxControl'; import TextControl from 'src/explore/components/controls/TextControl'; diff --git a/superset-frontend/src/datasource/DatasourceModal.tsx b/superset-frontend/src/datasource/DatasourceModal.tsx index 0a1b1c8e2b..daf47c2594 100644 --- a/superset-frontend/src/datasource/DatasourceModal.tsx +++ b/superset-frontend/src/datasource/DatasourceModal.tsx @@ -25,7 +25,7 @@ import Modal from 'src/common/components/Modal'; import AsyncEsmComponent from 'src/components/AsyncEsmComponent'; import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import withToasts from 'src/messageToasts/enhancers/withToasts'; const DatasourceEditor = AsyncEsmComponent(() => import('./DatasourceEditor')); diff --git a/superset-frontend/src/explore/components/DataTablesPane.tsx b/superset-frontend/src/explore/components/DataTablesPane.tsx index b1bc9081db..474ef94f83 100644 --- a/superset-frontend/src/explore/components/DataTablesPane.tsx +++ b/superset-frontend/src/explore/components/DataTablesPane.tsx @@ -23,7 +23,7 @@ import Tabs from 'src/common/components/Tabs'; import Loading from 'src/components/Loading'; import TableView, { EmptyWrapperType } from 'src/components/TableView'; import { getChartDataRequest } from 'src/chart/chartAction'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import { CopyToClipboardButton, FilterInput, diff --git a/superset-frontend/src/explore/components/DisplayQueryButton.jsx b/superset-frontend/src/explore/components/DisplayQueryButton.jsx index 1482156098..dbb5d8441b 100644 --- a/superset-frontend/src/explore/components/DisplayQueryButton.jsx +++ b/superset-frontend/src/explore/components/DisplayQueryButton.jsx @@ -30,7 +30,7 @@ import { DropdownButton } from 'react-bootstrap'; import { styled, t } from '@superset-ui/core'; import { Menu } from 'src/common/components'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; import CopyToClipboard from '../../components/CopyToClipboard'; import { getChartDataRequest } from '../../chart/chartAction'; import 
downloadAsImage from '../../utils/downloadAsImage'; diff --git a/superset-frontend/src/explore/components/PropertiesModal.tsx b/superset-frontend/src/explore/components/PropertiesModal.tsx index 1281bc382f..240d2b6840 100644 --- a/superset-frontend/src/explore/components/PropertiesModal.tsx +++ b/superset-frontend/src/explore/components/PropertiesModal.tsx @@ -32,7 +32,7 @@ import rison from 'rison'; import { t, SupersetClient } from '@superset-ui/core'; import Chart, { Slice } from 'src/types/Chart'; import FormLabel from 'src/components/FormLabel'; -import getClientErrorObject from '../../utils/getClientErrorObject'; +import { getClientErrorObject } from '../../utils/getClientErrorObject'; type PropertiesModalProps = { slice: Slice; diff --git a/superset-frontend/src/explore/index.jsx b/superset-frontend/src/explore/index.jsx index 25a704757a..83e4bc63dc 100644 --- a/superset-frontend/src/explore/index.jsx +++ b/superset-frontend/src/explore/index.jsx @@ -25,6 +25,8 @@ import { initFeatureFlags } from '../featureFlags'; import { initEnhancer } from '../reduxUtils'; import getInitialState from './reducers/getInitialState'; import rootReducer from './reducers/index'; +import initAsyncEvents from '../middleware/asyncEvent'; +import * as actions from '../chart/chartAction'; import App from './App'; @@ -35,10 +37,23 @@ const bootstrapData = JSON.parse( initFeatureFlags(bootstrapData.common.feature_flags); const initState = getInitialState(bootstrapData); +const asyncEventMiddleware = initAsyncEvents({ + config: bootstrapData.common.conf, + getPendingComponents: ({ charts }) => + Object.values(charts).filter(c => c.chartStatus === 'loading'), + successAction: (componentId, componentData) => + actions.chartUpdateSucceeded(componentData, componentId), + errorAction: (componentId, response) => + actions.chartUpdateFailed(response, componentId), +}); + const store = createStore( rootReducer, initState, - compose(applyMiddleware(thunk, logger), initEnhancer(false)), + compose( + applyMiddleware(thunk, logger, asyncEventMiddleware), + initEnhancer(false), + ), ); ReactDOM.render(, document.getElementById('app')); diff --git a/superset-frontend/src/featureFlags.ts b/superset-frontend/src/featureFlags.ts index 1e024a5165..93b909ddcb 100644 --- a/superset-frontend/src/featureFlags.ts +++ b/superset-frontend/src/featureFlags.ts @@ -34,6 +34,7 @@ export enum FeatureFlag { DISPLAY_MARKDOWN_HTML = 'DISPLAY_MARKDOWN_HTML', ESCAPE_MARKDOWN_HTML = 'ESCAPE_MARKDOWN_HTML', VERSIONED_EXPORT = 'VERSIONED_EXPORT', + GLOBAL_ASYNC_QUERIES = 'GLOBAL_ASYNC_QUERIES', } export type FeatureFlagMap = { diff --git a/superset-frontend/src/middleware/asyncEvent.ts b/superset-frontend/src/middleware/asyncEvent.ts new file mode 100644 index 0000000000..637bb1b38d --- /dev/null +++ b/superset-frontend/src/middleware/asyncEvent.ts @@ -0,0 +1,196 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { Middleware, MiddlewareAPI, Dispatch } from 'redux'; +import { makeApi, SupersetClient } from '@superset-ui/core'; +import { SupersetError } from 'src/components/ErrorMessage/types'; +import { isFeatureEnabled, FeatureFlag } from '../featureFlags'; +import { + getClientErrorObject, + parseErrorJson, +} from '../utils/getClientErrorObject'; + +export type AsyncEvent = { + id: string; + channel_id: string; + job_id: string; + user_id: string; + status: string; + errors: SupersetError[]; + result_url: string; +}; + +type AsyncEventOptions = { + config: { + GLOBAL_ASYNC_QUERIES_TRANSPORT: string; + GLOBAL_ASYNC_QUERIES_POLLING_DELAY: number; + }; + getPendingComponents: (state: any) => any[]; + successAction: (componentId: number, componentData: any) => { type: string }; + errorAction: (componentId: number, response: any) => { type: string }; + processEventsCallback?: (events: AsyncEvent[]) => void; // this is currently used only for tests +}; + +type CachedDataResponse = { + componentId: number; + status: string; + data: any; +}; + +const initAsyncEvents = (options: AsyncEventOptions) => { + // TODO: implement websocket support + const TRANSPORT_POLLING = 'polling'; + const { + config, + getPendingComponents, + successAction, + errorAction, + processEventsCallback, + } = options; + const transport = config.GLOBAL_ASYNC_QUERIES_TRANSPORT || TRANSPORT_POLLING; + const polling_delay = config.GLOBAL_ASYNC_QUERIES_POLLING_DELAY || 500; + + const middleware: Middleware = (store: MiddlewareAPI) => (next: Dispatch) => { + const JOB_STATUS = { + PENDING: 'pending', + RUNNING: 'running', + ERROR: 'error', + DONE: 'done', + }; + const LOCALSTORAGE_KEY = 'last_async_event_id'; + const POLLING_URL = '/api/v1/async_event/'; + let lastReceivedEventId: string | null; + + try { + lastReceivedEventId = localStorage.getItem(LOCALSTORAGE_KEY); + } catch (err) { + console.warn('failed to fetch last event Id from localStorage'); + } + + const fetchEvents = makeApi< + { last_id?: string | null }, + { result: AsyncEvent[] } + >({ + method: 'GET', + endpoint: POLLING_URL, + }); + + const fetchCachedData = async ( + asyncEvent: AsyncEvent, + componentId: number, + ): Promise => { + let status = 'success'; + let data; + try { + const { json } = await SupersetClient.get({ + endpoint: asyncEvent.result_url, + }); + data = 'result' in json ? json.result[0] : json; + } catch (response) { + status = 'error'; + data = await getClientErrorObject(response); + } + + return { componentId, status, data }; + }; + + const setLastId = (asyncEvent: AsyncEvent) => { + lastReceivedEventId = asyncEvent.id; + try { + localStorage.setItem(LOCALSTORAGE_KEY, lastReceivedEventId as string); + } catch (err) { + console.warn('Error saving event ID to localStorage', err); + } + }; + + const processEvents = async () => { + const state = store.getState(); + const queuedComponents = getPendingComponents(state); + const eventArgs = lastReceivedEventId + ? 
{ last_id: lastReceivedEventId } + : {}; + const events: AsyncEvent[] = []; + if (queuedComponents && queuedComponents.length) { + try { + const { result: events } = await fetchEvents(eventArgs); + if (events && events.length) { + const componentsByJobId = queuedComponents.reduce((acc, item) => { + acc[item.asyncJobId] = item; + return acc; + }, {}); + const fetchDataEvents: Promise[] = []; + events.forEach((asyncEvent: AsyncEvent) => { + const component = componentsByJobId[asyncEvent.job_id]; + if (!component) { + console.warn( + 'component not found for job_id', + asyncEvent.job_id, + ); + return setLastId(asyncEvent); + } + const componentId = component.id; + switch (asyncEvent.status) { + case JOB_STATUS.DONE: + fetchDataEvents.push( + fetchCachedData(asyncEvent, componentId), + ); + break; + case JOB_STATUS.ERROR: + store.dispatch( + errorAction(componentId, parseErrorJson(asyncEvent)), + ); + break; + default: + console.warn('received event with status', asyncEvent.status); + } + + return setLastId(asyncEvent); + }); + + const fetchResults = await Promise.all(fetchDataEvents); + fetchResults.forEach(result => { + if (result.status === 'success') { + store.dispatch(successAction(result.componentId, result.data)); + } else { + store.dispatch(errorAction(result.componentId, result.data)); + } + }); + } + } catch (err) { + console.warn(err); + } + } + + if (processEventsCallback) processEventsCallback(events); + + return setTimeout(processEvents, polling_delay); + }; + + if ( + isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES) && + transport === TRANSPORT_POLLING + ) + processEvents(); + + return action => next(action); + }; + + return middleware; +}; + +export default initAsyncEvents; diff --git a/superset-frontend/src/setup/setupApp.ts b/superset-frontend/src/setup/setupApp.ts index 1eef937bc7..740c97b355 100644 --- a/superset-frontend/src/setup/setupApp.ts +++ b/superset-frontend/src/setup/setupApp.ts @@ -19,7 +19,8 @@ /* eslint global-require: 0 */ import $ from 'jquery'; import { SupersetClient } from '@superset-ui/core'; -import getClientErrorObject, { +import { + getClientErrorObject, ClientErrorObject, } from '../utils/getClientErrorObject'; import setupErrorMessages from './setupErrorMessages'; diff --git a/superset-frontend/src/utils/common.js b/superset-frontend/src/utils/common.js index 033382294e..2753c0205c 100644 --- a/superset-frontend/src/utils/common.js +++ b/superset-frontend/src/utils/common.js @@ -21,7 +21,7 @@ import { getTimeFormatter, TimeFormats, } from '@superset-ui/core'; -import getClientErrorObject from './getClientErrorObject'; +import { getClientErrorObject } from './getClientErrorObject'; // ATTENTION: If you change any constants, make sure to also change constants.py diff --git a/superset-frontend/src/utils/getClientErrorObject.ts b/superset-frontend/src/utils/getClientErrorObject.ts index 274b8b4d93..269126169f 100644 --- a/superset-frontend/src/utils/getClientErrorObject.ts +++ b/superset-frontend/src/utils/getClientErrorObject.ts @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { SupersetClientResponse, t } from '@superset-ui/core'; +import { JsonObject, SupersetClientResponse, t } from '@superset-ui/core'; import { SupersetError, ErrorTypeEnum, @@ -36,7 +36,33 @@ export type ClientErrorObject = { stacktrace?: string; } & Partial; -export default function getClientErrorObject( +export function parseErrorJson(responseObject: JsonObject): ClientErrorObject { + let error = { ...responseObject }; + // Backwards compatibility for old error renderers with the new error object + if (error.errors && error.errors.length > 0) { + error.error = error.description = error.errors[0].message; + error.link = error.errors[0]?.extra?.link; + } + + if (error.stack) { + error = { + ...error, + error: + t('Unexpected error: ') + + (error.description || t('(no description, click to see stack trace)')), + stacktrace: error.stack, + }; + } else if (error.responseText && error.responseText.indexOf('CSRF') >= 0) { + error = { + ...error, + error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT), + }; + } + + return { ...error, error: error.error }; // explicit ClientErrorObject +} + +export function getClientErrorObject( response: SupersetClientResponse | (Response & { timeout: number }) | string, ): Promise { // takes a SupersetClientResponse as input, attempts to read response as Json if possible, @@ -54,33 +80,8 @@ export default function getClientErrorObject( .clone() .json() .then(errorJson => { - let error = { ...responseObject, ...errorJson }; - - // Backwards compatibility for old error renderers with the new error object - if (error.errors && error.errors.length > 0) { - error.error = error.description = error.errors[0].message; - error.link = error.errors[0]?.extra?.link; - } - - if (error.stack) { - error = { - ...error, - error: - t('Unexpected error: ') + - (error.description || - t('(no description, click to see stack trace)')), - stacktrace: error.stack, - }; - } else if ( - error.responseText && - error.responseText.indexOf('CSRF') >= 0 - ) { - error = { - ...error, - error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT), - }; - } - resolve(error); + const error = { ...responseObject, ...errorJson }; + resolve(parseErrorJson(error)); }) .catch(() => { // fall back to reading as text diff --git a/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx b/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx index dba16e3465..d1b7ef5355 100644 --- a/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx +++ b/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx @@ -29,7 +29,7 @@ import ConfirmStatusChange from 'src/components/ConfirmStatusChange'; import DeleteModal from 'src/components/DeleteModal'; import ListView, { ListViewProps } from 'src/components/ListView'; import SubMenu, { SubMenuProps } from 'src/components/Menu/SubMenu'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import withToasts from 'src/messageToasts/enhancers/withToasts'; import { IconName } from 'src/components/Icon'; import { useListViewResource } from 'src/views/CRUD/hooks'; diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx index fa8a3625ab..b8f061c1b9 100644 --- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx +++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx @@ -21,7 +21,7 @@ import { styled, t, SupersetClient } from '@superset-ui/core'; import 
InfoTooltip from 'src/common/components/InfoTooltip'; import { useSingleViewResource } from 'src/views/CRUD/hooks'; import withToasts from 'src/messageToasts/enhancers/withToasts'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import Icon from 'src/components/Icon'; import Modal from 'src/common/components/Modal'; import Tabs from 'src/common/components/Tabs'; diff --git a/superset-frontend/src/views/CRUD/hooks.ts b/superset-frontend/src/views/CRUD/hooks.ts index e209815aa3..dedcadd87a 100644 --- a/superset-frontend/src/views/CRUD/hooks.ts +++ b/superset-frontend/src/views/CRUD/hooks.ts @@ -25,7 +25,7 @@ import { FetchDataConfig } from 'src/components/ListView'; import { FilterValue } from 'src/components/ListView/types'; import Chart, { Slice } from 'src/types/Chart'; import copyTextToClipboard from 'src/utils/copy'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import { FavoriteStatus } from './types'; interface ListViewResourceState { diff --git a/superset-frontend/src/views/CRUD/utils.tsx b/superset-frontend/src/views/CRUD/utils.tsx index e22310dd4d..ea79121277 100644 --- a/superset-frontend/src/views/CRUD/utils.tsx +++ b/superset-frontend/src/views/CRUD/utils.tsx @@ -25,7 +25,7 @@ import { } from '@superset-ui/core'; import Chart from 'src/types/Chart'; import rison from 'rison'; -import getClientErrorObject from 'src/utils/getClientErrorObject'; +import { getClientErrorObject } from 'src/utils/getClientErrorObject'; import { FetchDataConfig } from 'src/components/ListView'; import { Dashboard } from './types'; diff --git a/superset/app.py b/superset/app.py index bdba4cc7f5..d8f2ad252f 100644 --- a/superset/app.py +++ b/superset/app.py @@ -30,6 +30,7 @@ from superset.extensions import ( _event_logger, APP_DIR, appbuilder, + async_query_manager, cache_manager, celery_app, csrf, @@ -127,6 +128,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-branches from superset.annotation_layers.api import AnnotationLayerRestApi from superset.annotation_layers.annotations.api import AnnotationRestApi + from superset.async_events.api import AsyncEventsRestApi from superset.cachekeys.api import CacheRestApi from superset.charts.api import ChartRestApi from superset.connectors.druid.views import ( @@ -201,6 +203,7 @@ class SupersetAppInitializer: # appbuilder.add_api(AnnotationRestApi) appbuilder.add_api(AnnotationLayerRestApi) + appbuilder.add_api(AsyncEventsRestApi) appbuilder.add_api(CacheRestApi) appbuilder.add_api(ChartRestApi) appbuilder.add_api(CssTemplateRestApi) @@ -498,6 +501,7 @@ class SupersetAppInitializer: self.configure_url_map_converters() self.configure_data_sources() self.configure_auth_provider() + self.configure_async_queries() # Hook that provides administrators a handle on the Flask APP # after initialization @@ -648,6 +652,10 @@ class SupersetAppInitializer: for ex in csrf_exempt_list: csrf.exempt(ex) + def configure_async_queries(self) -> None: + if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"): + async_query_manager.init_app(self.flask_app) + def register_blueprints(self) -> None: for bp in self.config["BLUEPRINTS"]: try: diff --git a/superset/async_events/__init__.py b/superset/async_events/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/superset/async_events/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/async_events/api.py b/superset/async_events/api.py new file mode 100644 index 0000000000..61b85ac6f9 --- /dev/null +++ b/superset/async_events/api.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from flask import request, Response +from flask_appbuilder import expose +from flask_appbuilder.api import BaseApi, safe +from flask_appbuilder.security.decorators import permission_name, protect + +from superset.extensions import async_query_manager, event_logger +from superset.utils.async_query_manager import AsyncQueryTokenException + +logger = logging.getLogger(__name__) + + +class AsyncEventsRestApi(BaseApi): + resource_name = "async_event" + allow_browser_login = True + include_route_methods = { + "events", + } + + @expose("/", methods=["GET"]) + @event_logger.log_this + @protect() + @safe + @permission_name("list") + def events(self) -> Response: + """ + Reads off of the Redis async events stream, using the user's JWT token and + optional query params for last event received. + --- + get: + description: >- + Reads off of the Redis events stream, using the user's JWT token and + optional query params for last event received. 
+ parameters: + - in: query + name: last_id + description: Last ID received by the client + schema: + type: string + responses: + 200: + description: Async event results + content: + application/json: + schema: + type: object + properties: + result: + type: array + items: + type: object + properties: + id: + type: string + channel_id: + type: string + job_id: + type: string + user_id: + type: integer + status: + type: string + msg: + type: string + cache_key: + type: string + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + """ + try: + async_channel_id = async_query_manager.parse_jwt_from_request(request)[ + "channel" + ] + last_event_id = request.args.get("last_id") + events = async_query_manager.read_events(async_channel_id, last_event_id) + + except AsyncQueryTokenException: + return self.response_401() + + return self.response(200, result=events) diff --git a/superset/charts/api.py b/superset/charts/api.py index a3a2737aae..50245be3a0 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -33,10 +33,13 @@ from werkzeug.wsgi import FileWrapper from superset import is_feature_enabled, thumbnail_cache from superset.charts.commands.bulk_delete import BulkDeleteChartCommand from superset.charts.commands.create import CreateChartCommand +from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.delete import DeleteChartCommand from superset.charts.commands.exceptions import ( ChartBulkDeleteFailedError, ChartCreateFailedError, + ChartDataCacheLoadError, + ChartDataQueryFailedError, ChartDeleteFailedError, ChartForbiddenError, ChartInvalidError, @@ -50,7 +53,6 @@ from superset.charts.dao import ChartDAO from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter from superset.charts.schemas import ( CHART_SCHEMAS, - ChartDataQueryContextSchema, ChartPostSchema, ChartPutSchema, get_delete_ids_schema, @@ -67,7 +69,12 @@ from superset.exceptions import SupersetSecurityException from superset.extensions import event_logger from superset.models.slice import Slice from superset.tasks.thumbnails import cache_chart_thumbnail -from superset.utils.core import ChartDataResultFormat, json_int_dttm_ser +from superset.utils.async_query_manager import AsyncQueryTokenException +from superset.utils.core import ( + ChartDataResultFormat, + ChartDataResultType, + json_int_dttm_ser, +) from superset.utils.screenshots import ChartScreenshot from superset.utils.urls import get_url_path from superset.views.base_api import ( @@ -93,6 +100,8 @@ class ChartRestApi(BaseSupersetModelRestApi): RouteMethod.RELATED, "bulk_delete", # not using RouteMethod since locally defined "data", + "data_from_cache", + "viz_types", "favorite_status", } class_permission_name = "SliceModelView" @@ -448,6 +457,39 @@ class ChartRestApi(BaseSupersetModelRestApi): except ChartBulkDeleteFailedError as ex: return self.response_422(message=str(ex)) + def get_data_response( + self, command: ChartDataCommand, force_cached: bool = False + ) -> Response: + try: + result = command.run(force_cached=force_cached) + except ChartDataCacheLoadError as exc: + return self.response_422(message=exc.message) + except ChartDataQueryFailedError as exc: + return self.response_400(message=exc.message) + + result_format = result["query_context"].result_format + if result_format == ChartDataResultFormat.CSV: + # return the first result + data = result["queries"][0]["data"] + return CsvResponse( + data, + status=200, + 
headers=generate_download_headers("csv"), + mimetype="application/csv", + ) + + if result_format == ChartDataResultFormat.JSON: + response_data = simplejson.dumps( + {"result": result["queries"]}, + default=json_int_dttm_ser, + ignore_nan=True, + ) + resp = make_response(response_data, 200) + resp.headers["Content-Type"] = "application/json; charset=utf-8" + return resp + + return self.response_400(message=f"Unsupported result_format: {result_format}") + @expose("/data", methods=["POST"]) @protect() @safe @@ -478,8 +520,16 @@ class ChartRestApi(BaseSupersetModelRestApi): application/json: schema: $ref: "#/components/schemas/ChartDataResponseSchema" + 202: + description: Async job details + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataAsyncResponseSchema" 400: $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' 500: $ref: '#/components/responses/500' """ @@ -490,47 +540,88 @@ class ChartRestApi(BaseSupersetModelRestApi): json_body = json.loads(request.form["form_data"]) else: return self.response_400(message="Request is not JSON") + try: - query_context = ChartDataQueryContextSchema().load(json_body) - except KeyError: - return self.response_400(message="Request is incorrect") + command = ChartDataCommand() + query_context = command.set_query_context(json_body) + command.validate() except ValidationError as error: return self.response_400( message=_("Request is incorrect: %(error)s", error=error.messages) ) - try: - query_context.raise_for_access() except SupersetSecurityException: return self.response_401() - payload = query_context.get_payload() - for query in payload: - if query.get("error"): - return self.response_400(message=f"Error: {query['error']}") - result_format = query_context.result_format - response = self.response_400( - message=f"Unsupported result_format: {result_format}" - ) + # TODO: support CSV, SQL query and other non-JSON types + if ( + is_feature_enabled("GLOBAL_ASYNC_QUERIES") + and query_context.result_format == ChartDataResultFormat.JSON + and query_context.result_type == ChartDataResultType.FULL + ): - if result_format == ChartDataResultFormat.CSV: - # return the first result - result = payload[0]["data"] - response = CsvResponse( - result, - status=200, - headers=generate_download_headers("csv"), - mimetype="application/csv", + try: + command.validate_async_request(request) + except AsyncQueryTokenException: + return self.response_401() + + result = command.run_async() + return self.response(202, **result) + + return self.get_data_response(command) + + @expose("/data/", methods=["GET"]) + @event_logger.log_this + @protect() + @safe + @statsd_metrics + def data_from_cache(self, cache_key: str) -> Response: + """ + Takes a query context cache key and returns payload + data response for the given query. + --- + get: + description: >- + Takes a query context cache key and returns payload data + response for the given query. 
+ parameters: + - in: path + schema: + type: string + name: cache_key + responses: + 200: + description: Query result + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + command = ChartDataCommand() + try: + cached_data = command.load_query_context_from_cache(cache_key) + command.set_query_context(cached_data) + command.validate() + except ChartDataCacheLoadError: + return self.response_404() + except ValidationError as error: + return self.response_400( + message=_("Request is incorrect: %(error)s", error=error.messages) ) + except SupersetSecurityException as exc: + logger.info(exc) + return self.response_401() - if result_format == ChartDataResultFormat.JSON: - response_data = simplejson.dumps( - {"result": payload}, default=json_int_dttm_ser, ignore_nan=True - ) - resp = make_response(response_data, 200) - resp.headers["Content-Type"] = "application/json; charset=utf-8" - response = resp - - return response + return self.get_data_response(command, True) @expose("//cache_screenshot/", methods=["GET"]) @protect() diff --git a/superset/charts/commands/data.py b/superset/charts/commands/data.py new file mode 100644 index 0000000000..275a723d60 --- /dev/null +++ b/superset/charts/commands/data.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import logging +from typing import Any, Dict + +from flask import Request +from marshmallow import ValidationError + +from superset import cache +from superset.charts.commands.exceptions import ( + ChartDataCacheLoadError, + ChartDataQueryFailedError, +) +from superset.charts.schemas import ChartDataQueryContextSchema +from superset.commands.base import BaseCommand +from superset.common.query_context import QueryContext +from superset.exceptions import CacheLoadError +from superset.extensions import async_query_manager +from superset.tasks.async_queries import load_chart_data_into_cache + +logger = logging.getLogger(__name__) + + +class ChartDataCommand(BaseCommand): + def __init__(self) -> None: + self._form_data: Dict[str, Any] + self._query_context: QueryContext + self._async_channel_id: str + + def run(self, **kwargs: Any) -> Dict[str, Any]: + # caching is handled in query_context.get_df_payload + # (also evals `force` property) + cache_query_context = kwargs.get("cache", False) + force_cached = kwargs.get("force_cached", False) + try: + payload = self._query_context.get_payload( + cache_query_context=cache_query_context, force_cached=force_cached + ) + except CacheLoadError as exc: + raise ChartDataCacheLoadError(exc.message) + + # TODO: QueryContext should support SIP-40 style errors + for query in payload["queries"]: + if query.get("error"): + raise ChartDataQueryFailedError(f"Error: {query['error']}") + + return_value = { + "query_context": self._query_context, + "queries": payload["queries"], + } + if cache_query_context: + return_value.update(cache_key=payload["cache_key"]) + + return return_value + + def run_async(self) -> Dict[str, Any]: + job_metadata = async_query_manager.init_job(self._async_channel_id) + load_chart_data_into_cache.delay(job_metadata, self._form_data) + + return job_metadata + + def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext: + self._form_data = form_data + try: + self._query_context = ChartDataQueryContextSchema().load(self._form_data) + except KeyError: + raise ValidationError("Request is incorrect") + except ValidationError as error: + raise error + + return self._query_context + + def validate(self) -> None: + self._query_context.raise_for_access() + + def validate_async_request(self, request: Request) -> None: + jwt_data = async_query_manager.parse_jwt_from_request(request) + self._async_channel_id = jwt_data["channel"] + + def load_query_context_from_cache( # pylint: disable=no-self-use + self, cache_key: str + ) -> Dict[str, Any]: + cache_value = cache.get(cache_key) + if not cache_value: + raise ChartDataCacheLoadError("Cached data not found") + + return cache_value["data"] diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py index 51b5ca84f0..d1fa375c02 100644 --- a/superset/charts/commands/exceptions.py +++ b/superset/charts/commands/exceptions.py @@ -90,6 +90,14 @@ class ChartBulkDeleteFailedError(DeleteFailedError): message = _("Charts could not be deleted.") +class ChartDataQueryFailedError(CommandException): + pass + + +class ChartDataCacheLoadError(CommandException): + pass + + class ChartBulkDeleteFailedReportsExistError(ChartBulkDeleteFailedError): message = _("There are associated alerts or reports") diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 347189a13e..5de166c86e 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -1105,6 +1105,18 @@ class ChartDataResponseSchema(Schema): ) +class ChartDataAsyncResponseSchema(Schema): + 
channel_id = fields.String( + description="Unique session async channel ID", allow_none=False, + ) + job_id = fields.String(description="Unique async job ID", allow_none=False,) + user_id = fields.String(description="Requesting user ID", allow_none=True,) + status = fields.String(description="Status value for async job", allow_none=False,) + result_url = fields.String( + description="Unique result URL for fetching async query data", allow_none=False, + ) + + class ChartFavStarResponseResult(Schema): id = fields.Integer(description="The Chart id") value = fields.Boolean(description="The FaveStar value") @@ -1130,6 +1142,7 @@ class ImportV1ChartSchema(Schema): CHART_SCHEMAS = ( ChartDataQueryContextSchema, ChartDataResponseSchema, + ChartDataAsyncResponseSchema, # TODO: These should optimally be included in the QueryContext schema as an `anyOf` # in ChartDataPostPricessingOperation.options, but since `anyOf` is not # by Marshmallow<3, this is not currently possible. diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 25113c36ae..a7900cf667 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -17,7 +17,7 @@ import copy import logging import math -from datetime import datetime, timedelta +from datetime import timedelta from typing import Any, cast, ClassVar, Dict, List, Optional, Union import numpy as np @@ -30,13 +30,17 @@ from superset.charts.dao import ChartDAO from superset.common.query_object import QueryObject from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry -from superset.exceptions import QueryObjectValidationError, SupersetException +from superset.exceptions import ( + CacheLoadError, + QueryObjectValidationError, + SupersetException, +) from superset.extensions import cache_manager, security_manager from superset.stats_logger import BaseStatsLogger from superset.utils import core as utils +from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import DTTM_ALIAS from superset.views.utils import get_viz -from superset.viz import set_and_log_cache config = app.config stats_logger: BaseStatsLogger = config["STATS_LOGGER"] @@ -78,6 +82,13 @@ class QueryContext: self.custom_cache_timeout = custom_cache_timeout self.result_type = result_type or utils.ChartDataResultType.FULL self.result_format = result_format or utils.ChartDataResultFormat.JSON + self.cache_values = { + "datasource": datasource, + "queries": queries, + "force": force, + "result_type": result_type, + "result_format": result_format, + } def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: """Returns a pandas dataframe based on the query object""" @@ -142,8 +153,11 @@ class QueryContext: return df.to_dict(orient="records") - def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]: + def get_single_payload( + self, query_obj: QueryObject, **kwargs: Any + ) -> Dict[str, Any]: """Returns a payload of metadata and data""" + force_cached = kwargs.get("force_cached", False) if self.result_type == utils.ChartDataResultType.QUERY: return { "query": self.datasource.get_query_str(query_obj.to_dict()), @@ -159,8 +173,7 @@ class QueryContext: query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"]) query_obj.row_offset = 0 query_obj.columns = [o.column_name for o in self.datasource.columns] - payload = self.get_df_payload(query_obj) - + payload = self.get_df_payload(query_obj, force_cached=force_cached) df = 
payload["df"] status = payload["status"] if status != utils.QueryStatus.FAILED: @@ -186,9 +199,28 @@ class QueryContext: return {"data": payload["data"]} return payload - def get_payload(self) -> List[Dict[str, Any]]: - """Get all the payloads from the QueryObjects""" - return [self.get_single_payload(query_object) for query_object in self.queries] + def get_payload(self, **kwargs: Any) -> Dict[str, Any]: + cache_query_context = kwargs.get("cache_query_context", False) + force_cached = kwargs.get("force_cached", False) + + # Get all the payloads from the QueryObjects + query_results = [ + self.get_single_payload(query_object, force_cached=force_cached) + for query_object in self.queries + ] + return_value = {"queries": query_results} + + if cache_query_context: + cache_key = self.cache_key() + set_and_log_cache( + cache_manager.cache, + cache_key, + {"data": self.cache_values}, + self.cache_timeout, + ) + return_value["cache_key"] = cache_key # type: ignore + + return return_value @property def cache_timeout(self) -> int: @@ -203,7 +235,22 @@ class QueryContext: return self.datasource.database.cache_timeout return config["CACHE_DEFAULT_TIMEOUT"] - def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: + def cache_key(self, **extra: Any) -> str: + """ + The QueryContext cache key is made out of the key/values from + self.cached_values, plus any other key/values in `extra`. It includes only data + required to rehydrate a QueryContext object. + """ + key_prefix = "qc-" + cache_dict = self.cache_values.copy() + cache_dict.update(extra) + + return generate_cache_key(cache_dict, key_prefix) + + def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: + """ + Returns a QueryObject cache key for objects in self.queries + """ extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict()) cache_key = ( @@ -215,7 +262,7 @@ class QueryContext: and self.datasource.is_rls_supported else [], changed_on=self.datasource.changed_on, - **kwargs + **kwargs, ) if query_obj else None @@ -298,12 +345,12 @@ class QueryContext: self, query_obj: QueryObject, **kwargs: Any ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" - cache_key = self.cache_key(query_obj, **kwargs) + force_cached = kwargs.get("force_cached", False) + cache_key = self.query_cache_key(query_obj) logger.info("Cache key: %s", cache_key) is_loaded = False stacktrace = None df = pd.DataFrame() - cached_dttm = datetime.utcnow().isoformat().split(".")[0] cache_value = None status = None query = "" @@ -327,6 +374,12 @@ class QueryContext: ) logger.info("Serving from cache") + if force_cached and not is_loaded: + logger.warning( + "force_cached (QueryContext): value not found for key %s", cache_key + ) + raise CacheLoadError("Error loading data from cache") + if query_obj and not is_loaded: try: invalid_columns = [ @@ -367,13 +420,11 @@ class QueryContext: if is_loaded and cache_key and status != utils.QueryStatus.FAILED: set_and_log_cache( - cache_key=cache_key, - df=df, - query=query, - annotation_data=annotation_data, - cached_dttm=cached_dttm, - cache_timeout=self.cache_timeout, - datasource_uid=self.datasource.uid, + cache_manager.data_cache, + cache_key, + {"df": df, "query": query, "annotation_data": annotation_data}, + self.cache_timeout, + self.datasource.uid, ) return { "cache_key": cache_key, diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 6eb1231f2e..edcc325fed 100644 --- a/superset/common/query_object.py +++ 
b/superset/common/query_object.py @@ -277,7 +277,8 @@ class QueryObject: :param df: DataFrame returned from database model. :return: new DataFrame to which all post processing operations have been applied - :raises ChartDataValidationError: If the post processing operation in incorrect + :raises QueryObjectValidationError: If the post processing operation + is incorrect """ for post_process in self.post_processing: operation = post_process.get("operation") diff --git a/superset/config.py b/superset/config.py index 22d3cc5262..f20cbff716 100644 --- a/superset/config.py +++ b/superset/config.py @@ -327,6 +327,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = { "DISPLAY_MARKDOWN_HTML": True, # When True, this escapes HTML (rather than rendering it) in Markdown components "ESCAPE_MARKDOWN_HTML": False, + "GLOBAL_ASYNC_QUERIES": False, "VERSIONED_EXPORT": False, # Note that: RowLevelSecurityFilter is only given by default to the Admin role # and the Admin Role does have the all_datasources security permission. @@ -406,6 +407,9 @@ CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"} # Cache for datasource metadata and query results DATA_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"} +# store cache keys by datasource UID (via CacheKey) for custom processing/invalidation +STORE_CACHE_KEYS_IN_METADATA_DB = False + # CORS Options ENABLE_CORS = False CORS_OPTIONS: Dict[Any, Any] = {} @@ -965,6 +969,23 @@ SIP_15_TOAST_MESSAGE = ( # conventions and such. You can find examples in the tests. SQLA_TABLE_MUTATOR = lambda table: table +# Global async query config options. +# Requires GLOBAL_ASYNC_QUERIES feature flag to be enabled. +GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = { + "port": 6379, + "host": "127.0.0.1", + "password": "", + "db": 0, +} +GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-" +GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000 +GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000 +GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token" +GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False +GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" +GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling" +GLOBAL_ASYNC_QUERIES_POLLING_DELAY = 500 + if CONFIG_PATH_ENV_VAR in os.environ: # Explicitly import config module that is not necessarily in pythonpath; useful # for case where app is being executed via pex. diff --git a/superset/exceptions.py b/superset/exceptions.py index fd95a59e2c..52c19c3a02 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from flask_babel import gettext as _ @@ -65,6 +65,14 @@ class SupersetSecurityException(SupersetException): self.payload = payload +class SupersetVizException(SupersetException): + status = 400 + + def __init__(self, errors: List[SupersetError]) -> None: + super(SupersetVizException, self).__init__(str(errors)) + self.errors = errors + + class NoDataException(SupersetException): status = 400 @@ -93,6 +101,10 @@ class QueryObjectValidationError(SupersetException): status = 400 +class CacheLoadError(SupersetException): + status = 404 + + class DashboardImportException(SupersetException): pass diff --git a/superset/extensions.py b/superset/extensions.py index 7011bb237d..8f5bc6d5a3 100644 --- a/superset/extensions.py +++ b/superset/extensions.py @@ -27,6 +27,7 @@ from flask_talisman import Talisman from flask_wtf.csrf import CSRFProtect from werkzeug.local import LocalProxy +from superset.utils.async_query_manager import AsyncQueryManager from superset.utils.cache_manager import CacheManager from superset.utils.feature_flag_manager import FeatureFlagManager from superset.utils.machine_auth import MachineAuthProviderFactory @@ -97,6 +98,7 @@ class UIManifestProcessor: APP_DIR = os.path.dirname(__file__) appbuilder = AppBuilder(update_perms=False) +async_query_manager = AsyncQueryManager() cache_manager = CacheManager() celery_app = celery.Celery() csrf = CSRFProtect() diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py new file mode 100644 index 0000000000..b8db82b9af --- /dev/null +++ b/superset/tasks/async_queries.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
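+# Celery tasks for the GLOBAL_ASYNC_QUERIES feature. Each task runs a chart data or
+# explore_json query in a worker, caches the resulting payload, and reports progress
+# via async_query_manager.update_job() so that clients polling /api/v1/async_event/
+# can follow the job and fetch the cached result from the returned result_url.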
+ +import logging +from typing import Any, cast, Dict, Optional + +from flask import current_app + +from superset import app +from superset.exceptions import SupersetVizException +from superset.extensions import async_query_manager, cache_manager, celery_app +from superset.utils.cache import generate_cache_key, set_and_log_cache +from superset.views.utils import get_datasource_info, get_viz + +logger = logging.getLogger(__name__) +query_timeout = current_app.config[ + "SQLLAB_ASYNC_TIME_LIMIT_SEC" +] # TODO: new config key + + +@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) +def load_chart_data_into_cache( + job_metadata: Dict[str, Any], form_data: Dict[str, Any], +) -> None: + from superset.charts.commands.data import ( + ChartDataCommand, + ) # load here due to circular imports + + with app.app_context(): # type: ignore + try: + command = ChartDataCommand() + command.set_query_context(form_data) + result = command.run(cache=True) + cache_key = result["cache_key"] + result_url = f"/api/v1/chart/data/{cache_key}" + async_query_manager.update_job( + job_metadata, async_query_manager.STATUS_DONE, result_url=result_url, + ) + except Exception as exc: + # TODO: QueryContext should support SIP-40 style errors + error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member + errors = [{"message": error}] + async_query_manager.update_job( + job_metadata, async_query_manager.STATUS_ERROR, errors=errors + ) + raise exc + + return None + + +@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout) +def load_explore_json_into_cache( + job_metadata: Dict[str, Any], + form_data: Dict[str, Any], + response_type: Optional[str] = None, + force: bool = False, +) -> None: + with app.app_context(): # type: ignore + cache_key_prefix = "ejr-" # ejr: explore_json request + try: + datasource_id, datasource_type = get_datasource_info(None, None, form_data) + + viz_obj = get_viz( + datasource_type=cast(str, datasource_type), + datasource_id=datasource_id, + form_data=form_data, + force=force, + ) + # run query & cache results + payload = viz_obj.get_payload() + if viz_obj.has_error(payload): + raise SupersetVizException(errors=payload["errors"]) + + # cache form_data for async retrieval + cache_value = {"form_data": form_data, "response_type": response_type} + cache_key = generate_cache_key(cache_value, cache_key_prefix) + set_and_log_cache(cache_manager.cache, cache_key, cache_value) + result_url = f"/superset/explore_json/data/{cache_key}" + async_query_manager.update_job( + job_metadata, async_query_manager.STATUS_DONE, result_url=result_url, + ) + except Exception as exc: + if isinstance(exc, SupersetVizException): + errors = exc.errors # pylint: disable=no-member + else: + error = ( + exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member + ) + errors = [error] + + async_query_manager.update_job( + job_metadata, async_query_manager.STATUS_ERROR, errors=errors + ) + raise exc + + return None diff --git a/superset/utils/async_query_manager.py b/superset/utils/async_query_manager.py new file mode 100644 index 0000000000..42d2c130bd --- /dev/null +++ b/superset/utils/async_query_manager.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import logging +import uuid +from typing import Any, Dict, List, Optional, Tuple + +import jwt +import redis +from flask import Flask, Request, Response, session + +logger = logging.getLogger(__name__) + + +class AsyncQueryTokenException(Exception): + pass + + +class AsyncQueryJobException(Exception): + pass + + +def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]: + return { + "channel_id": channel_id, + "job_id": job_id, + "user_id": session.get("user_id"), + "status": kwargs.get("status"), + "errors": kwargs.get("errors", []), + "result_url": kwargs.get("result_url"), + } + + +def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]: + event_id = event_data[0] + event_payload = event_data[1]["data"] + return {"id": event_id, **json.loads(event_payload)} + + +def increment_id(redis_id: str) -> str: + # redis stream IDs are in this format: '1607477697866-0' + try: + prefix, last = redis_id[:-1], int(redis_id[-1]) + return prefix + str(last + 1) + except Exception: # pylint: disable=broad-except + return redis_id + + +class AsyncQueryManager: + MAX_EVENT_COUNT = 100 + STATUS_PENDING = "pending" + STATUS_RUNNING = "running" + STATUS_ERROR = "error" + STATUS_DONE = "done" + + def __init__(self) -> None: + super().__init__() + self._redis: redis.Redis + self._stream_prefix: str = "" + self._stream_limit: Optional[int] + self._stream_limit_firehose: Optional[int] + self._jwt_cookie_name: str + self._jwt_cookie_secure: bool = False + self._jwt_secret: str + + def init_app(self, app: Flask) -> None: + config = app.config + if ( + config["CACHE_CONFIG"]["CACHE_TYPE"] == "null" + or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null" + ): + raise Exception( + """ + Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured + and non-null in order to enable async queries + """ + ) + + if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32: + raise AsyncQueryTokenException( + "Please provide a JWT secret at least 32 bytes long" + ) + + self._redis = redis.Redis( # type: ignore + **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True + ) + self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"] + self._stream_limit_firehose = config[ + "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE" + ] + self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"] + self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"] + self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"] + + @app.after_request + def validate_session( # pylint: disable=unused-variable + response: Response, + ) -> Response: + reset_token = False + user_id = session["user_id"] if "user_id" in session else None + + if "async_channel_id" not in session or "async_user_id" not in session: + reset_token = True + elif user_id != session["async_user_id"]: + reset_token = 
True + + if reset_token: + async_channel_id = str(uuid.uuid4()) + session["async_channel_id"] = async_channel_id + session["async_user_id"] = user_id + + sub = str(user_id) if user_id else None + token = self.generate_jwt({"channel": async_channel_id, "sub": sub}) + + response.set_cookie( + self._jwt_cookie_name, + value=token, + httponly=True, + secure=self._jwt_cookie_secure, + # max_age=max_age or config.cookie_max_age, + # domain=config.cookie_domain, + # path=config.access_cookie_path, + # samesite=config.cookie_samesite + ) + + return response + + def generate_jwt(self, data: Dict[str, Any]) -> str: + encoded_jwt = jwt.encode(data, self._jwt_secret, algorithm="HS256") + return encoded_jwt.decode("utf-8") + + def parse_jwt(self, token: str) -> Dict[str, Any]: + data = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + return data + + def parse_jwt_from_request(self, request: Request) -> Dict[str, Any]: + token = request.cookies.get(self._jwt_cookie_name) + if not token: + raise AsyncQueryTokenException("Token not preset") + + try: + return self.parse_jwt(token) + except Exception as exc: + logger.warning(exc) + raise AsyncQueryTokenException("Failed to parse token") + + def init_job(self, channel_id: str) -> Dict[str, Any]: + job_id = str(uuid.uuid4()) + return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING) + + def read_events( + self, channel: str, last_id: Optional[str] + ) -> List[Optional[Dict[str, Any]]]: + stream_name = f"{self._stream_prefix}{channel}" + start_id = increment_id(last_id) if last_id else "-" + results = self._redis.xrange( # type: ignore + stream_name, start_id, "+", self.MAX_EVENT_COUNT + ) + return [] if not results else list(map(parse_event, results)) + + def update_job( + self, job_metadata: Dict[str, Any], status: str, **kwargs: Any + ) -> None: + if "channel_id" not in job_metadata: + raise AsyncQueryJobException("No channel ID specified") + + if "job_id" not in job_metadata: + raise AsyncQueryJobException("No job ID specified") + + updates = {"status": status, **kwargs} + event_data = {"data": json.dumps({**job_metadata, **updates})} + + full_stream_name = f"{self._stream_prefix}full" + scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}" + + logger.debug("********** logging event data to stream %s", scoped_stream_name) + logger.debug(event_data) + + self._redis.xadd( # type: ignore + scoped_stream_name, event_data, "*", self._stream_limit + ) + self._redis.xadd( # type: ignore + full_stream_name, event_data, "*", self._stream_limit_firehose + ) diff --git a/superset/utils/cache.py b/superset/utils/cache.py index f0b24b26ae..729c316866 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -14,16 +14,65 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
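+# Shared cache helpers: generate_cache_key() builds an md5 key from a JSON-serialized
+# dict, and set_and_log_cache() writes the value to the given cache instance,
+# optionally recording a CacheKey row when STORE_CACHE_KEYS_IN_METADATA_DB is enabled.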
+import hashlib +import json import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Union from flask import current_app as app, request from flask_caching import Cache from werkzeug.wrappers.etag import ETagResponseMixin +from superset import db from superset.extensions import cache_manager +from superset.models.cache import CacheKey +from superset.stats_logger import BaseStatsLogger +from superset.utils.core import json_int_dttm_ser + +config = app.config # type: ignore +stats_logger: BaseStatsLogger = config["STATS_LOGGER"] +logger = logging.getLogger(__name__) + +# TODO: DRY up cache key code +def json_dumps(obj: Any, sort_keys: bool = False) -> str: + return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys) + + +def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str: + json_data = json_dumps(values_dict, sort_keys=True) + hash_str = hashlib.md5(json_data.encode("utf-8")).hexdigest() + return f"{key_prefix}{hash_str}" + + +def set_and_log_cache( + cache_instance: Cache, + cache_key: str, + cache_value: Dict[str, Any], + cache_timeout: Optional[int] = None, + datasource_uid: Optional[str] = None, +) -> None: + timeout = cache_timeout if cache_timeout else config["CACHE_DEFAULT_TIMEOUT"] + try: + dttm = datetime.utcnow().isoformat().split(".")[0] + value = {**cache_value, "dttm": dttm} + cache_instance.set(cache_key, value, timeout=timeout) + stats_logger.incr("set_cache_key") + + if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]: + ck = CacheKey( + cache_key=cache_key, + cache_timeout=cache_timeout, + datasource_uid=datasource_uid, + ) + db.session.add(ck) + except Exception as ex: # pylint: disable=broad-except + # cache.set call can fail if the backend is down or if + # the key is too large or whatever other reasons + logger.warning("Could not cache key %s", cache_key) + logger.exception(ex) + # If a user sets `max_age` to 0, for long the browser should cache the # resource? Flask-Caching will cache forever, but for the HTTP header we need diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 84ee9b7ffb..5d4c2013dd 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -238,7 +238,7 @@ def pivot( # pylint: disable=too-many-arguments Default to 'All'. :param flatten_columns: Convert column names to strings :return: A pivot table - :raises ChartDataValidationError: If the request in incorrect + :raises QueryObjectValidationError: If the request in incorrect """ if not index: raise QueryObjectValidationError( @@ -293,7 +293,7 @@ def aggregate( :param groupby: columns to aggregate :param aggregates: A mapping from metric column to the function used to aggregate values. - :raises ChartDataValidationError: If the request in incorrect + :raises QueryObjectValidationError: If the request in incorrect """ aggregates = aggregates or {} aggregate_funcs = _get_aggregate_funcs(df, aggregates) @@ -313,7 +313,7 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: :param columns: columns by by which to sort. The key specifies the column name, value specifies if sorting in ascending order. 
:return: Sorted DataFrame - :raises ChartDataValidationError: If the request in incorrect + :raises QueryObjectValidationError: If the request in incorrect """ return df.sort_values(by=list(columns.keys()), ascending=list(columns.values())) @@ -348,7 +348,7 @@ def rolling( # pylint: disable=too-many-arguments :param min_periods: The minimum amount of periods required for a row to be included in the result set. :return: DataFrame with the rolling columns - :raises ChartDataValidationError: If the request in incorrect + :raises QueryObjectValidationError: If the request in incorrect """ rolling_type_options = rolling_type_options or {} df_rolling = df[columns.keys()] @@ -408,7 +408,7 @@ def select( For instance, `{'y': 'y2'}` will rename the column `y` to `y2`. :return: Subset of columns in original DataFrame - :raises ChartDataValidationError: If the request in incorrect + :raises QueryObjectValidationError: If the request in incorrect """ df_select = df.copy(deep=False) if columns: @@ -433,7 +433,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame unchanged. :param periods: periods to shift for calculating difference. :return: DataFrame with diffed columns - :raises ChartDataValidationError: If the request in incorrect + :raises QueryObjectValidationError: If the request in incorrect """ df_diff = df[columns.keys()] df_diff = df_diff.diff(periods=periods) diff --git a/superset/views/api.py b/superset/views/api.py index a5090b31c9..1b19455126 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -44,7 +44,8 @@ class Api(BaseSupersetView): """ query_context = QueryContext(**json.loads(request.form["query_context"])) query_context.raise_for_access() - payload_json = query_context.get_payload() + result = query_context.get_payload() + payload_json = result["queries"] return json.dumps( payload_json, default=utils.json_int_dttm_ser, ignore_nan=True ) diff --git a/superset/views/base.py b/superset/views/base.py index 2f87a81f97..9ed65222bc 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -77,6 +77,8 @@ FRONTEND_CONF_KEYS = ( "SUPERSET_WEBSERVER_DOMAINS", "SQLLAB_SAVE_WARNING_MESSAGE", "DISPLAY_MAX_ROW", + "GLOBAL_ASYNC_QUERIES_TRANSPORT", + "GLOBAL_ASYNC_QUERIES_POLLING_DELAY", ) logger = logging.getLogger(__name__) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 0f05f07ee1..2e0d3b2369 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -126,6 +126,7 @@ class BaseSupersetModelRestApi(ModelRestApi): method_permission_name = { "bulk_delete": "delete", "data": "list", + "data_from_cache": "list", "delete": "delete", "distinct": "list", "export": "mulexport", diff --git a/superset/views/core.py b/superset/views/core.py index 4e062da974..63591b73ef 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -28,7 +28,11 @@ import simplejson as json from flask import abort, flash, g, Markup, redirect, render_template, request, Response from flask_appbuilder import expose from flask_appbuilder.models.sqla.interface import SQLAInterface -from flask_appbuilder.security.decorators import has_access, has_access_api +from flask_appbuilder.security.decorators import ( + has_access, + has_access_api, + permission_name, +) from flask_appbuilder.security.sqla import models as ab_models from flask_babel import gettext as __, lazy_gettext as _ from jinja2.exceptions import TemplateError @@ -66,6 +70,7 @@ from superset.dashboards.dao import DashboardDAO from superset.databases.dao import 
DatabaseDAO from superset.databases.filters import DatabaseFilter from superset.exceptions import ( + CacheLoadError, CertificateException, DatabaseNotFound, SerializationError, @@ -73,6 +78,7 @@ from superset.exceptions import ( SupersetSecurityException, SupersetTimeoutException, ) +from superset.extensions import async_query_manager, cache_manager from superset.jinja_context import get_template_processor from superset.models.core import Database, FavStar, Log from superset.models.dashboard import Dashboard @@ -87,8 +93,10 @@ from superset.security.analytics_db_safety import ( ) from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sql_validators import get_validator_by_name +from superset.tasks.async_queries import load_explore_json_into_cache from superset.typing import FlaskResponse from superset.utils import core as utils +from superset.utils.async_query_manager import AsyncQueryTokenException from superset.utils.cache import etag_cache from superset.utils.dates import now_as_float from superset.views.base import ( @@ -113,6 +121,7 @@ from superset.views.utils import ( apply_display_max_row_limit, bootstrap_user_data, check_datasource_perms, + check_explore_cache_perms, check_slice_perms, get_cta_schema_name, get_dashboard_extra_filters, @@ -484,6 +493,43 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods payload = viz_obj.get_payload() return data_payload_response(*viz_obj.payload_json_and_has_error(payload)) + @event_logger.log_this + @api + @has_access_api + @handle_api_exception + @permission_name("explore_json") + @expose("/explore_json/data/", methods=["GET"]) + @etag_cache(check_perms=check_explore_cache_perms) + def explore_json_data(self, cache_key: str) -> FlaskResponse: + """Serves cached result data for async explore_json calls + + `self.generate_json` receives this input and returns different + payloads based on the request args in the first block + + TODO: form_data should not be loaded twice from cache + (also loaded in `check_explore_cache_perms`) + """ + try: + cached = cache_manager.cache.get(cache_key) + if not cached: + raise CacheLoadError("Cached data not found") + + form_data = cached.get("form_data") + response_type = cached.get("response_type") + + datasource_id, datasource_type = get_datasource_info(None, None, form_data) + + viz_obj = get_viz( + datasource_type=cast(str, datasource_type), + datasource_id=datasource_id, + form_data=form_data, + force_cached=True, + ) + + return self.generate_json(viz_obj, response_type) + except SupersetException as ex: + return json_error_response(utils.error_msg_from_exception(ex), 400) + EXPLORE_JSON_METHODS = ["POST"] if not is_feature_enabled("ENABLE_EXPLORE_JSON_CSRF_PROTECTION"): EXPLORE_JSON_METHODS.append("GET") @@ -528,11 +574,31 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods datasource_id, datasource_type, form_data ) + force = request.args.get("force") == "true" + + # TODO: support CSV, SQL query and other non-JSON types + if ( + is_feature_enabled("GLOBAL_ASYNC_QUERIES") + and response_type == utils.ChartDataResultFormat.JSON + ): + try: + async_channel_id = async_query_manager.parse_jwt_from_request( + request + )["channel"] + job_metadata = async_query_manager.init_job(async_channel_id) + load_explore_json_into_cache.delay( + job_metadata, form_data, response_type, force + ) + except AsyncQueryTokenException: + return json_error_response("Not authorized", 401) + + return json_success(json.dumps(job_metadata), status=202) + 
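+ # Fall back to the existing synchronous behaviour when GLOBAL_ASYNC_QUERIES is
+ # disabled or the requested response type is not JSON.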
viz_obj = get_viz( datasource_type=cast(str, datasource_type), datasource_id=datasource_id, form_data=form_data, - force=request.args.get("force") == "true", + force=force, ) return self.generate_json(viz_obj, response_type) diff --git a/superset/views/utils.py b/superset/views/utils.py index 08de168003..28104aa86a 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -34,10 +34,12 @@ from superset import app, dataframe, db, is_feature_enabled, result_set from superset.connectors.connector_registry import ConnectorRegistry from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( + CacheLoadError, SerializationError, SupersetException, SupersetSecurityException, ) +from superset.extensions import cache_manager from superset.legacy import update_time_range from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -108,13 +110,19 @@ def get_permissions( def get_viz( - form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False + form_data: FormData, + datasource_type: str, + datasource_id: int, + force: bool = False, + force_cached: bool = False, ) -> BaseViz: viz_type = form_data.get("viz_type", "table") datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session ) - viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force) + viz_obj = viz.viz_types[viz_type]( + datasource, form_data=form_data, force=force, force_cached=force_cached + ) return viz_obj @@ -422,10 +430,26 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool: return obj and user in obj.owners +def check_explore_cache_perms(_self: Any, cache_key: str) -> None: + """ + Loads async explore_json request data from cache and performs access check + + :param _self: the Superset view instance + :param cache_key: the cache key passed into /explore_json/data/ + :raises SupersetSecurityException: If the user cannot access the resource + """ + cached = cache_manager.cache.get(cache_key) + if not cached: + raise CacheLoadError("Cached data not found") + + check_datasource_perms(_self, form_data=cached["form_data"]) + + def check_datasource_perms( _self: Any, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None, + **kwargs: Any ) -> None: """ Check if user can access a cached response from explore_json. 
@@ -438,7 +462,7 @@ def check_datasource_perms( :raises SupersetSecurityException: If the user cannot access the resource """ - form_data = get_form_data()[0] + form_data = kwargs["form_data"] if "form_data" in kwargs else get_form_data()[0] try: datasource_id, datasource_type = get_datasource_info( diff --git a/superset/viz.py b/superset/viz.py index de5d597f73..9b7d46ac40 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -38,6 +38,7 @@ from typing import ( Optional, Set, Tuple, + Type, TYPE_CHECKING, Union, ) @@ -57,6 +58,7 @@ from superset import app, db, is_feature_enabled from superset.constants import NULL_STRING from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( + CacheLoadError, NullValueException, QueryObjectValidationError, SpatialException, @@ -66,6 +68,7 @@ from superset.models.cache import CacheKey from superset.models.helpers import QueryResult from superset.typing import QueryObjectDict, VizData, VizPayload from superset.utils import core as utils +from superset.utils.cache import set_and_log_cache from superset.utils.core import ( DTTM_ALIAS, JS_MAX_INTEGER, @@ -97,37 +100,6 @@ METRIC_KEYS = [ ] -def set_and_log_cache( - cache_key: str, - df: pd.DataFrame, - query: str, - cached_dttm: str, - cache_timeout: int, - datasource_uid: Optional[str], - annotation_data: Optional[Dict[str, Any]] = None, -) -> None: - try: - cache_value = dict( - dttm=cached_dttm, df=df, query=query, annotation_data=annotation_data or {} - ) - stats_logger.incr("set_cache_key") - cache_manager.data_cache.set(cache_key, cache_value, timeout=cache_timeout) - - if datasource_uid: - ck = CacheKey( - cache_key=cache_key, - cache_timeout=cache_timeout, - datasource_uid=datasource_uid, - ) - db.session.add(ck) - except Exception as ex: - # cache.set call can fail if the backend is down or if - # the key is too large or whatever other reasons - logger.warning("Could not cache key {}".format(cache_key)) - logger.exception(ex) - cache_manager.data_cache.delete(cache_key) - - class BaseViz: """All visualizations derive this base class""" @@ -144,6 +116,7 @@ class BaseViz: datasource: "BaseDatasource", form_data: Dict[str, Any], force: bool = False, + force_cached: bool = False, ) -> None: if not datasource: raise QueryObjectValidationError(_("Viz is missing a datasource")) @@ -164,6 +137,7 @@ class BaseViz: self.results: Optional[QueryResult] = None self.errors: List[Dict[str, Any]] = [] self.force = force + self._force_cached = force_cached self.from_dttm: Optional[datetime] = None self.to_dttm: Optional[datetime] = None @@ -180,6 +154,10 @@ class BaseViz: self.applied_filters: List[Dict[str, str]] = [] self.rejected_filters: List[Dict[str, str]] = [] + @property + def force_cached(self) -> bool: + return self._force_cached + def process_metrics(self) -> None: # metrics in TableViz is order sensitive, so metric_dict should be # OrderedDict @@ -272,7 +250,7 @@ class BaseViz: "columns": [o.column_name for o in self.datasource.columns], } ) - df = self.get_df(query_obj) + df = self.get_df_payload(query_obj)["df"] # leverage caching logic return df.to_dict(orient="records") def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: @@ -518,7 +496,6 @@ class BaseViz: is_loaded = False stacktrace = None df = None - cached_dttm = datetime.utcnow().isoformat().split(".")[0] if cache_key and cache_manager.data_cache and not self.force: cache_value = cache_manager.data_cache.get(cache_key) if cache_value: @@ -539,6 +516,11 @@ class BaseViz: 
logger.info("Serving from cache") if query_obj and not is_loaded: + if self.force_cached: + logger.warning( + f"force_cached (viz.py): value not found for cache key {cache_key}" + ) + raise CacheLoadError(_("Cached value not found")) try: invalid_columns = [ col @@ -590,12 +572,11 @@ class BaseViz: if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED: set_and_log_cache( - cache_key=cache_key, - df=df, - query=self.query, - cached_dttm=cached_dttm, - cache_timeout=self.cache_timeout, - datasource_uid=self.datasource.uid, + cache_manager.data_cache, + cache_key, + {"df": df, "query": self.query}, + self.cache_timeout, + self.datasource.uid, ) return { "cache_key": self._any_cache_key, @@ -618,13 +599,15 @@ class BaseViz: obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys ) - def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]: - has_error = ( + def has_error(self, payload: VizPayload) -> bool: + return ( payload.get("status") == utils.QueryStatus.FAILED or payload.get("error") is not None or bool(payload.get("errors")) ) - return self.json_dumps(payload), has_error + + def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]: + return self.json_dumps(payload), self.has_error(payload) @property def data(self) -> Dict[str, Any]: @@ -638,7 +621,7 @@ class BaseViz: return content def get_csv(self) -> Optional[str]: - df = self.get_df() + df = self.get_df_payload()["df"] # leverage caching logic include_index = not isinstance(df.index, pd.RangeIndex) return df.to_csv(index=include_index, **config["CSV_EXPORT"]) @@ -1721,7 +1704,7 @@ class SunburstViz(BaseViz): def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None - fd = self.form_data + fd = copy.deepcopy(self.form_data) cols = fd.get("groupby") or [] cols.extend(["m1", "m2"]) metric = utils.get_metric_name(fd["metric"]) @@ -2983,12 +2966,14 @@ class PartitionViz(NVD3TimeSeriesViz): return self.nest_values(levels) +def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]: + return set(cls.__subclasses__()).union( + [sc for c in cls.__subclasses__() for sc in get_subclasses(c)] + ) + + viz_types = { o.viz_type: o - for o in globals().values() - if ( - inspect.isclass(o) - and issubclass(o, BaseViz) - and o.viz_type not in config["VIZ_TYPE_DENYLIST"] - ) + for o in get_subclasses(BaseViz) + if o.viz_type not in config["VIZ_TYPE_DENYLIST"] } diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py index 600f44141a..798ce42f27 100644 --- a/superset/viz_sip38.py +++ b/superset/viz_sip38.py @@ -57,13 +57,13 @@ from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult from superset.typing import QueryObjectDict, VizData, VizPayload from superset.utils import core as utils +from superset.utils.cache import set_and_log_cache from superset.utils.core import ( DTTM_ALIAS, JS_MAX_INTEGER, merge_extra_filters, to_adhoc, ) -from superset.viz import set_and_log_cache if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -518,10 +518,9 @@ class BaseViz: if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED: set_and_log_cache( + cache_manager.data_cache, cache_key, - df, - self.query, - cached_dttm, + {"df": df, "query": self.query}, self.cache_timeout, self.datasource.uid, ) diff --git a/tests/async_events/__init__.py b/tests/async_events/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/async_events/__init__.py @@ 
-0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/async_events/api_tests.py b/tests/async_events/api_tests.py new file mode 100644 index 0000000000..04d838b97b --- /dev/null +++ b/tests/async_events/api_tests.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
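+# Tests for the /api/v1/async_event/ endpoint: Redis stream reads via xrange,
+# last_id pagination, and 401 responses when the session or JWT cookie is missing.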
+import json +from typing import Optional +from unittest import mock + +from superset.extensions import async_query_manager +from tests.base_tests import SupersetTestCase +from tests.test_app import app + + +class TestAsyncEventApi(SupersetTestCase): + UUID = "943c920-32a5-412a-977d-b8e47d36f5a4" + + def fetch_events(self, last_id: Optional[str] = None): + base_uri = "api/v1/async_event/" + uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri + return self.client.get(uri) + + @mock.patch("uuid.uuid4", return_value=UUID) + def test_events(self, mock_uuid4): + async_query_manager.init_app(app) + self.login(username="admin") + with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange: + rv = self.fetch_events() + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID + mock_xrange.assert_called_with(channel_id, "-", "+", 100) + self.assertEqual(response, {"result": []}) + + @mock.patch("uuid.uuid4", return_value=UUID) + def test_events_last_id(self, mock_uuid4): + async_query_manager.init_app(app) + self.login(username="admin") + with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange: + rv = self.fetch_events("1607471525180-0") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID + mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100) + self.assertEqual(response, {"result": []}) + + @mock.patch("uuid.uuid4", return_value=UUID) + def test_events_results(self, mock_uuid4): + async_query_manager.init_app(app) + self.login(username="admin") + with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange: + mock_xrange.return_value = [ + ( + "1607477697866-0", + { + "data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463"}' + }, + ), + ( + "1607477697993-0", + { + "data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d"}' + }, + ), + ] + rv = self.fetch_events() + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID + mock_xrange.assert_called_with(channel_id, "-", "+", 100) + expected = { + "result": [ + { + "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", + "errors": [], + "id": "1607477697866-0", + "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512", + "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463", + "status": "done", + "user_id": "1", + }, + { + "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", + "errors": [], + "id": "1607477697993-0", + "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c", + "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d", + "status": "done", + "user_id": "1", + }, + ] + } + self.assertEqual(response, expected) + + def test_events_no_login(self): + async_query_manager.init_app(app) + rv = self.fetch_events() + assert rv.status_code == 401 + + def test_events_no_token(self): + self.login(username="admin") + self.client.set_cookie( + "localhost", 
app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "" + ) + rv = self.fetch_events() + assert rv.status_code == 401 diff --git a/tests/cache_tests.py b/tests/cache_tests.py index 0b887622ef..3ffd52a378 100644 --- a/tests/cache_tests.py +++ b/tests/cache_tests.py @@ -35,6 +35,7 @@ class TestCache(SupersetTestCase): cache_manager.data_cache.clear() def test_no_data_cache(self): + data_cache_config = app.config["DATA_CACHE_CONFIG"] app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"} cache_manager.init_app(app) @@ -48,11 +49,15 @@ class TestCache(SupersetTestCase): resp_from_cache = self.get_json_resp( json_endpoint, {"form_data": json.dumps(slc.viz.form_data)} ) + # restore DATA_CACHE_CONFIG + app.config["DATA_CACHE_CONFIG"] = data_cache_config self.assertFalse(resp["is_cached"]) self.assertFalse(resp_from_cache["is_cached"]) def test_slice_data_cache(self): # Override cache config + data_cache_config = app.config["DATA_CACHE_CONFIG"] + cache_default_timeout = app.config["CACHE_DEFAULT_TIMEOUT"] app.config["CACHE_DEFAULT_TIMEOUT"] = 100 app.config["DATA_CACHE_CONFIG"] = { "CACHE_TYPE": "simple", @@ -87,5 +92,6 @@ class TestCache(SupersetTestCase): self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"])) # reset cache config - app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"} + app.config["DATA_CACHE_CONFIG"] = data_cache_config + app.config["CACHE_DEFAULT_TIMEOUT"] = cache_default_timeout cache_manager.init_app(app) diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 27c88886e1..092c354199 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -31,21 +31,21 @@ from sqlalchemy import and_ from sqlalchemy.sql import func from tests.test_app import app -from superset.connectors.sqla.models import SqlaTable -from superset.utils.core import AnnotationType, get_example_database -from tests.fixtures.energy_dashboard import load_energy_table_with_slice -from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice +from superset.charts.commands.data import ChartDataCommand from superset.connectors.connector_registry import ConnectorRegistry -from superset.extensions import db, security_manager +from superset.connectors.sqla.models import SqlaTable +from superset.extensions import async_query_manager, cache_manager, db, security_manager from superset.models.annotations import AnnotationLayer from superset.models.core import Database, FavStar, FavStarClassName from superset.models.dashboard import Dashboard from superset.models.reports import ReportSchedule, ReportScheduleType from superset.models.slice import Slice from superset.utils import core as utils +from superset.utils.core import AnnotationType, get_example_database from tests.base_api_tests import ApiOwnersTestCaseMixin -from tests.base_tests import SupersetTestCase +from tests.base_tests import SupersetTestCase, post_assert_metric, test_client + from tests.fixtures.importexport import ( chart_config, chart_metadata_config, @@ -53,6 +53,7 @@ from tests.fixtures.importexport import ( dataset_config, dataset_metadata_config, ) +from tests.fixtures.energy_dashboard import load_energy_table_with_slice from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice from tests.annotation_layers.fixtures import create_annotation_layers @@ -99,6 +100,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): db.session.commit() return slice + 
@pytest.fixture(autouse=True) + def clear_data_cache(self): + with app.app_context(): + cache_manager.data_cache.clear() + yield + @pytest.fixture() def create_charts(self): with self.create_app().app_context(): @@ -1287,6 +1294,155 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): if get_example_database().backend != "presto": assert "('boy' = 'boy')" in result + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + def test_chart_data_async(self): + """ + Chart data API: Test chart data query (async) + """ + async_query_manager.init_app(app) + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + self.assertEqual(rv.status_code, 202) + data = json.loads(rv.data.decode("utf-8")) + keys = list(data.keys()) + self.assertCountEqual( + keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] + ) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + def test_chart_data_async_results_type(self): + """ + Chart data API: Test chart data query non-JSON format (async) + """ + async_query_manager.init_app(app) + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + request_payload["result_type"] = "results" + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + self.assertEqual(rv.status_code, 200) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + def test_chart_data_async_invalid_token(self): + """ + Chart data API: Test chart data query (async) + """ + async_query_manager.init_app(app) + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + test_client.set_cookie( + "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo" + ) + rv = post_assert_metric(test_client, CHART_DATA_URI, request_payload, "data") + self.assertEqual(rv.status_code, 401) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + @mock.patch.object(ChartDataCommand, "load_query_context_from_cache") + def test_chart_data_cache(self, load_qc_mock): + """ + Chart data cache API: Test chart data async cache request + """ + async_query_manager.init_app(app) + self.login(username="admin") + table = self.get_table_by_name("birth_names") + query_context = get_query_context(table.name, table.id, table.type) + load_qc_mock.return_value = query_context + orig_run = ChartDataCommand.run + + def mock_run(self, **kwargs): + assert kwargs["force_cached"] == True + # override force_cached to get result from DB + return orig_run(self, force_cached=False) + + with mock.patch.object(ChartDataCommand, "run", new=mock_run): + rv = self.get_assert_metric( + f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" + ) + data = json.loads(rv.data.decode("utf-8")) + + self.assertEqual(rv.status_code, 200) + self.assertEqual(data["result"][0]["rowcount"], 45) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + @mock.patch.object(ChartDataCommand, "load_query_context_from_cache") + def 
test_chart_data_cache_run_failed(self, load_qc_mock): + """ + Chart data cache API: Test chart data async cache request with run failure + """ + async_query_manager.init_app(app) + self.login(username="admin") + table = self.get_table_by_name("birth_names") + query_context = get_query_context(table.name, table.id, table.type) + load_qc_mock.return_value = query_context + rv = self.get_assert_metric( + f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" + ) + data = json.loads(rv.data.decode("utf-8")) + + self.assertEqual(rv.status_code, 422) + self.assertEqual(data["message"], "Error loading data from cache") + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + @mock.patch.object(ChartDataCommand, "load_query_context_from_cache") + def test_chart_data_cache_no_login(self, load_qc_mock): + """ + Chart data cache API: Test chart data async cache request (no login) + """ + async_query_manager.init_app(app) + table = self.get_table_by_name("birth_names") + query_context = get_query_context(table.name, table.id, table.type) + load_qc_mock.return_value = query_context + orig_run = ChartDataCommand.run + + def mock_run(self, **kwargs): + assert kwargs["force_cached"] == True + # override force_cached to get result from DB + return orig_run(self, force_cached=False) + + with mock.patch.object(ChartDataCommand, "run", new=mock_run): + rv = self.get_assert_metric( + f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" + ) + + self.assertEqual(rv.status_code, 401) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + def test_chart_data_cache_key_error(self): + """ + Chart data cache API: Test chart data async cache request with invalid cache key + """ + async_query_manager.init_app(app) + self.login(username="admin") + rv = self.get_assert_metric( + f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" + ) + + self.assertEqual(rv.status_code, 404) + def test_export_chart(self): """ Chart API: Test export chart diff --git a/tests/core_tests.py b/tests/core_tests.py index 6ac990ba90..e011fd69ec 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -52,6 +52,7 @@ from superset import ( from superset.connectors.sqla.models import SqlaTable from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec +from superset.extensions import async_query_manager from superset.models import core as models from superset.models.annotations import Annotation, AnnotationLayer from superset.models.dashboard import Dashboard @@ -602,10 +603,13 @@ class TestCore(SupersetTestCase): ) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] def test_cache_logging(self): + store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] + app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True girls_slice = self.get_slice("Girls", db.session) self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id)) ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() assert ck.datasource_uid == f"{girls_slice.table.id}__table" + app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys def test_shortner(self): self.login(username="admin") @@ -841,6 +845,191 @@ class TestCore(SupersetTestCase): "The datasource associated with this chart no longer exists", ) + def test_explore_json(self): + tbl_id = self.table_ids.get("birth_names") + form_data = { + "queryFields": { + "metrics": "metrics", + 
"groupby": "groupby", + "columns": "groupby", + }, + "datasource": f"{tbl_id}__table", + "viz_type": "dist_bar", + "time_range_endpoints": ["inclusive", "exclusive"], + "granularity_sqla": "ds", + "time_range": "No filter", + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["gender"], + "row_limit": 100, + } + self.login(username="admin") + rv = self.client.post( + "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, + ) + data = json.loads(rv.data.decode("utf-8")) + + self.assertEqual(rv.status_code, 200) + self.assertEqual(data["rowcount"], 2) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + def test_explore_json_async(self): + tbl_id = self.table_ids.get("birth_names") + form_data = { + "queryFields": { + "metrics": "metrics", + "groupby": "groupby", + "columns": "groupby", + }, + "datasource": f"{tbl_id}__table", + "viz_type": "dist_bar", + "time_range_endpoints": ["inclusive", "exclusive"], + "granularity_sqla": "ds", + "time_range": "No filter", + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["gender"], + "row_limit": 100, + } + async_query_manager.init_app(app) + self.login(username="admin") + rv = self.client.post( + "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, + ) + data = json.loads(rv.data.decode("utf-8")) + keys = list(data.keys()) + + self.assertEqual(rv.status_code, 202) + self.assertCountEqual( + keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] + ) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + GLOBAL_ASYNC_QUERIES=True, + ) + def test_explore_json_async_results_format(self): + tbl_id = self.table_ids.get("birth_names") + form_data = { + "queryFields": { + "metrics": "metrics", + "groupby": "groupby", + "columns": "groupby", + }, + "datasource": f"{tbl_id}__table", + "viz_type": "dist_bar", + "time_range_endpoints": ["inclusive", "exclusive"], + "granularity_sqla": "ds", + "time_range": "No filter", + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["gender"], + "row_limit": 100, + } + async_query_manager.init_app(app) + self.login(username="admin") + rv = self.client.post( + "/superset/explore_json/?results=true", + data={"form_data": json.dumps(form_data)}, + ) + self.assertEqual(rv.status_code, 200) + + @mock.patch( + "superset.utils.cache_manager.CacheManager.cache", + new_callable=mock.PropertyMock, + ) + @mock.patch("superset.viz.BaseViz.force_cached", new_callable=mock.PropertyMock) + def test_explore_json_data(self, mock_force_cached, mock_cache): + tbl_id = self.table_ids.get("birth_names") + form_data = dict( + { + "form_data": { + "queryFields": { + "metrics": "metrics", + "groupby": "groupby", + "columns": "groupby", + }, + "datasource": f"{tbl_id}__table", + "viz_type": "dist_bar", + "time_range_endpoints": ["inclusive", "exclusive"], + "granularity_sqla": "ds", + "time_range": "No filter", + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["gender"], + "row_limit": 100, + } + } + ) + + class MockCache: + def get(self, key): + return form_data + + def set(self): + return None + + mock_cache.return_value = MockCache() + mock_force_cached.return_value = False + + self.login(username="admin") + rv = self.client.get("/superset/explore_json/data/valid-cache-key") + data = json.loads(rv.data.decode("utf-8")) + + self.assertEqual(rv.status_code, 200) + self.assertEqual(data["rowcount"], 2) + + @mock.patch( + 
"superset.utils.cache_manager.CacheManager.cache", + new_callable=mock.PropertyMock, + ) + def test_explore_json_data_no_login(self, mock_cache): + tbl_id = self.table_ids.get("birth_names") + form_data = dict( + { + "form_data": { + "queryFields": { + "metrics": "metrics", + "groupby": "groupby", + "columns": "groupby", + }, + "datasource": f"{tbl_id}__table", + "viz_type": "dist_bar", + "time_range_endpoints": ["inclusive", "exclusive"], + "granularity_sqla": "ds", + "time_range": "No filter", + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["gender"], + "row_limit": 100, + } + } + ) + + class MockCache: + def get(self, key): + return form_data + + def set(self): + return None + + mock_cache.return_value = MockCache() + + rv = self.client.get("/superset/explore_json/data/valid-cache-key") + self.assertEqual(rv.status_code, 401) + + def test_explore_json_data_invalid_cache_key(self): + self.login(username="admin") + cache_key = "invalid-cache-key" + rv = self.client.get(f"/superset/explore_json/data/{cache_key}") + data = json.loads(rv.data.decode("utf-8")) + + self.assertEqual(rv.status_code, 404) + self.assertEqual(data["error"], "Cached data not found") + @mock.patch( "superset.security.SupersetSecurityManager.get_schemas_accessible_by_user" ) diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index 3833ee0522..5bccf07507 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -82,7 +82,7 @@ class TestQueryContext(SupersetTestCase): # construct baseline cache_key query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - cache_key_original = query_context.cache_key(query_object) + cache_key_original = query_context.query_cache_key(query_object) # make temporary change and revert it to refresh the changed_on property datasource = ConnectorRegistry.get_datasource( @@ -99,7 +99,7 @@ class TestQueryContext(SupersetTestCase): # create new QueryContext with unchanged attributes and extract new cache_key query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - cache_key_new = query_context.cache_key(query_object) + cache_key_new = query_context.query_cache_key(query_object) # the new cache_key should be different due to updated datasource self.assertNotEqual(cache_key_original, cache_key_new) @@ -115,20 +115,20 @@ class TestQueryContext(SupersetTestCase): # construct baseline cache_key from query_context with post processing operation query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - cache_key_original = query_context.cache_key(query_object) + cache_key_original = query_context.query_cache_key(query_object) # ensure added None post_processing operation doesn't change cache_key payload["queries"][0]["post_processing"].append(None) query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - cache_key_with_null = query_context.cache_key(query_object) + cache_key_with_null = query_context.query_cache_key(query_object) self.assertEqual(cache_key_original, cache_key_with_null) # ensure query without post processing operation is different payload["queries"][0].pop("post_processing") query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - cache_key_without_post_processing = query_context.cache_key(query_object) + cache_key_without_post_processing = query_context.query_cache_key(query_object) 
self.assertNotEqual(cache_key_original, cache_key_without_post_processing) def test_query_context_time_range_endpoints(self): @@ -179,13 +179,10 @@ class TestQueryContext(SupersetTestCase): query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) - data = responses[0]["data"] + data = responses["queries"][0]["data"] self.assertIn("name,sum__num\n", data) self.assertEqual(len(data.split("\n")), 12) - ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() - assert ck.datasource_uid == f"{table.id}__table" - def test_sql_injection_via_groupby(self): """ Ensure that calling invalid columns names in groupby are caught @@ -197,7 +194,7 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["groupby"] = ["currentDatabase()"] query_context = ChartDataQueryContextSchema().load(payload) query_payload = query_context.get_payload() - assert query_payload[0].get("error") is not None + assert query_payload["queries"][0].get("error") is not None def test_sql_injection_via_columns(self): """ @@ -212,7 +209,7 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["columns"] = ["*, 'extra'"] query_context = ChartDataQueryContextSchema().load(payload) query_payload = query_context.get_payload() - assert query_payload[0].get("error") is not None + assert query_payload["queries"][0].get("error") is not None def test_sql_injection_via_metrics(self): """ @@ -233,7 +230,7 @@ class TestQueryContext(SupersetTestCase): ] query_context = ChartDataQueryContextSchema().load(payload) query_payload = query_context.get_payload() - assert query_payload[0].get("error") is not None + assert query_payload["queries"][0].get("error") is not None def test_samples_response_type(self): """ @@ -248,7 +245,7 @@ class TestQueryContext(SupersetTestCase): query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) - data = responses[0]["data"] + data = responses["queries"][0]["data"] self.assertIsInstance(data, list) self.assertEqual(len(data), 5) self.assertNotIn("sum__num", data[0]) @@ -265,7 +262,7 @@ class TestQueryContext(SupersetTestCase): query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) - response = responses[0] + response = responses["queries"][0] self.assertEqual(len(response), 2) self.assertEqual(response["language"], "sql") self.assertIn("SELECT", response["query"]) diff --git a/tests/superset_test_config.py b/tests/superset_test_config.py index 88a79d378d..c23d2007ef 100644 --- a/tests/superset_test_config.py +++ b/tests/superset_test_config.py @@ -22,7 +22,7 @@ from tests.superset_test_custom_template_processors import CustomPrestoTemplateP AUTH_USER_REGISTRATION_ROLE = "alpha" SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db") -DEBUG = True +DEBUG = False SUPERSET_WEBSERVER_PORT = 8081 # Allowing SQLALCHEMY_DATABASE_URI and SQLALCHEMY_EXAMPLES_URI to be defined as an env vars for @@ -96,6 +96,8 @@ DATA_CACHE_CONFIG = { "CACHE_KEY_PREFIX": "superset_data_cache", } +GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me-test-secret-change-me" + class CeleryConfig(object): BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/tasks/__init__.py @@ -0,0 +1,16 @@ +# 
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py new file mode 100644 index 0000000000..6fe2e7c319 --- /dev/null +++ b/tests/tasks/async_queries_tests.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Unit tests for async query celery jobs in Superset""" +import re +from unittest import mock +from uuid import uuid4 + +import pytest + +from superset import db +from superset.charts.commands.data import ChartDataCommand +from superset.charts.commands.exceptions import ChartDataQueryFailedError +from superset.connectors.sqla.models import SqlaTable +from superset.exceptions import SupersetException +from superset.extensions import async_query_manager +from superset.tasks.async_queries import ( + load_chart_data_into_cache, + load_explore_json_into_cache, +) +from tests.base_tests import SupersetTestCase +from tests.fixtures.query_context import get_query_context +from tests.test_app import app + + +def get_table_by_name(name: str) -> SqlaTable: + with app.app_context(): + return db.session.query(SqlaTable).filter_by(table_name=name).one() + + +class TestAsyncQueries(SupersetTestCase): + @mock.patch.object(async_query_manager, "update_job") + def test_load_chart_data_into_cache(self, mock_update_job): + async_query_manager.init_app(app) + table = get_table_by_name("birth_names") + form_data = get_query_context(table.name, table.id, table.type) + job_metadata = { + "channel_id": str(uuid4()), + "job_id": str(uuid4()), + "user_id": 1, + "status": "pending", + "errors": [], + } + + load_chart_data_into_cache(job_metadata, form_data) + + mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY) + + @mock.patch.object( + ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo") + ) + @mock.patch.object(async_query_manager, "update_job") + def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command): + async_query_manager.init_app(app) + table = get_table_by_name("birth_names") + form_data = get_query_context(table.name, table.id, table.type) + job_metadata = { + "channel_id": str(uuid4()), + "job_id": str(uuid4()), + "user_id": 1, + "status": "pending", + "errors": [], + } + with pytest.raises(ChartDataQueryFailedError): + load_chart_data_into_cache(job_metadata, form_data) + + mock_run_command.assert_called_with(cache=True) + errors = [{"message": "Error: foo"}] + mock_update_job.assert_called_with(job_metadata, "error", errors=errors) + + @mock.patch.object(async_query_manager, "update_job") + def test_load_explore_json_into_cache(self, mock_update_job): + async_query_manager.init_app(app) + table = get_table_by_name("birth_names") + form_data = { + "queryFields": { + "metrics": "metrics", + "groupby": "groupby", + "columns": "groupby", + }, + "datasource": f"{table.id}__table", + "viz_type": "dist_bar", + "time_range_endpoints": ["inclusive", "exclusive"], + "granularity_sqla": "ds", + "time_range": "No filter", + "metrics": ["count"], + "adhoc_filters": [], + "groupby": ["gender"], + "row_limit": 100, + } + job_metadata = { + "channel_id": str(uuid4()), + "job_id": str(uuid4()), + "user_id": 1, + "status": "pending", + "errors": [], + } + + load_explore_json_into_cache(job_metadata, form_data) + + mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY) + + @mock.patch.object(async_query_manager, "update_job") + def test_load_explore_json_into_cache_error(self, mock_update_job): + async_query_manager.init_app(app) + form_data = {} + job_metadata = { + "channel_id": str(uuid4()), + "job_id": str(uuid4()), + "user_id": 1, + "status": "pending", + "errors": [], + } + + with pytest.raises(SupersetException): + load_explore_json_into_cache(job_metadata, form_data) + + errors = ["The datasource associated with 
this chart no longer exists"] + mock_update_job.assert_called_with(job_metadata, "error", errors=errors) diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 1dffdcd2ad..09fd3a7c91 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -163,9 +163,20 @@ class TestBaseViz(SupersetTestCase): datasource.database.cache_timeout = 1666 self.assertEqual(1666, test_viz.cache_timeout) + datasource.database.cache_timeout = None + test_viz = viz.BaseViz(datasource, form_data={}) + self.assertEqual( + app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"], + test_viz.cache_timeout, + ) + + data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] + app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None datasource.database.cache_timeout = None test_viz = viz.BaseViz(datasource, form_data={}) self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout) + # restore DATA_CACHE_CONFIG timeout + app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout class TestTableViz(SupersetTestCase):
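
The `viz_tests.py` additions above pin down the timeout fallback that `BaseViz.cache_timeout` is expected to follow once `DATA_CACHE_CONFIG` participates: the database-level timeout first, then `DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]`, and finally the global `CACHE_DEFAULT_TIMEOUT`. The sketch below restates that precedence as plain Python so the assertions are easier to follow. It is an illustration only: the helper name `resolve_cache_timeout` is invented for this example, it is not Superset's actual implementation, and the chart- and datasource-level steps at the top are assumed from the surrounding test context rather than shown in this hunk.

```python
# Illustrative sketch of the precedence asserted in the viz_tests.py hunk above.
# Not Superset's actual BaseViz implementation; resolve_cache_timeout is a
# hypothetical helper, and steps 1-2 are assumed rather than shown in the diff.
from typing import Any, Dict, Optional


def resolve_cache_timeout(
    form_data: Dict[str, Any],
    datasource_timeout: Optional[int],
    database_timeout: Optional[int],
    config: Dict[str, Any],
) -> Optional[int]:
    # 1. (assumed) an explicit per-chart timeout in form_data wins
    if form_data.get("cache_timeout") is not None:
        return int(form_data["cache_timeout"])
    # 2. (assumed) then the datasource's own timeout
    if datasource_timeout is not None:
        return datasource_timeout
    # 3. then the database's timeout (the `1666` assertion above)
    if database_timeout is not None:
        return database_timeout
    # 4. then the data-cache default, when configured
    data_cache_timeout = config.get("DATA_CACHE_CONFIG", {}).get(
        "CACHE_DEFAULT_TIMEOUT"
    )
    if data_cache_timeout is not None:
        return data_cache_timeout
    # 5. finally the global cache default (the last assertion above)
    return config.get("CACHE_DEFAULT_TIMEOUT")


# Mirrors the final assertion: with no other timeouts set and DATA_CACHE_CONFIG's
# default removed, the global CACHE_DEFAULT_TIMEOUT applies.
assert resolve_cache_timeout({}, None, None, {"CACHE_DEFAULT_TIMEOUT": 300}) == 300
```

The explicit save-and-restore of `DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]` in the test exists because the config dict is shared process-wide; without restoring it, later tests in the same run would see the mutated value.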