diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 2e7d79e405..4e5a860fd3 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -623,6 +623,11 @@ cd superset-frontend
npm run test
```
+To run a single test file:
+```bash
+npm run test -- path/to/file.js
+```
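+
+To run tests whose names match a pattern (the frontend suite runs on Jest, whose `-t` flag filters by test name):
+```bash
+npm run test -- path/to/file.js -t 'name of test'
+```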
+
### Integration Testing
We use [Cypress](https://www.cypress.io/) for integration tests. Tests can be run with `tox -e cypress`. To open Cypress and explore tests, first set up and run the test server:
diff --git a/UPDATING.md b/UPDATING.md
index 79777b79f2..9e92d36c46 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -23,7 +23,7 @@ This file documents any backwards-incompatible changes in Superset and
assists people when migrating to a new version.
## Next
-
+- [11499](https://github.com/apache/incubator-superset/pull/11499): Breaking change: the `STORE_CACHE_KEYS_IN_METADATA_DB` config flag has been added (default `False`) to control whether `CacheKey` records are written to the metadata DB. Previously, `CacheKey` recording was always enabled.
- [11920](https://github.com/apache/incubator-superset/pull/11920): Undoes the DB migration from [11714](https://github.com/apache/incubator-superset/pull/11714) to prevent adding new columns to the logs table. Deploying a SHA between these two PRs may result in locking your DB.
- [11704](https://github.com/apache/incubator-superset/pull/11704) Breaking change: Jinja templating for SQL queries has been updated, removing default modules such as `datetime` and `random` and enforcing static template values. To restore or extend functionality, use `JINJA_CONTEXT_ADDONS` and `CUSTOM_TEMPLATE_PROCESSORS` in `superset_config.py`.
- [11714](https://github.com/apache/incubator-superset/pull/11714): Logs
diff --git a/setup.cfg b/setup.cfg
index 28c8e77df6..38856a3504 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,redis,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false
diff --git a/superset-frontend/spec/javascripts/middleware/asyncEvent_spec.ts b/superset-frontend/spec/javascripts/middleware/asyncEvent_spec.ts
new file mode 100644
index 0000000000..e42ac9152f
--- /dev/null
+++ b/superset-frontend/spec/javascripts/middleware/asyncEvent_spec.ts
@@ -0,0 +1,265 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import fetchMock from 'fetch-mock';
+import sinon from 'sinon';
+import * as featureFlags from 'src/featureFlags';
+import initAsyncEvents from 'src/middleware/asyncEvent';
+
+jest.useFakeTimers();
+
+describe('asyncEvent middleware', () => {
+ const next = sinon.spy();
+ const state = {
+ charts: {
+ 123: {
+ id: 123,
+ status: 'loading',
+ asyncJobId: 'foo123',
+ },
+ 345: {
+ id: 345,
+ status: 'loading',
+ asyncJobId: 'foo345',
+ },
+ },
+ };
+ const events = [
+ {
+ status: 'done',
+ result_url: '/api/v1/chart/data/cache-key-1',
+ job_id: 'foo123',
+ channel_id: '999',
+ errors: [],
+ },
+ {
+ status: 'done',
+ result_url: '/api/v1/chart/data/cache-key-2',
+ job_id: 'foo345',
+ channel_id: '999',
+ errors: [],
+ },
+ ];
+ const mockStore = {
+ getState: () => state,
+ dispatch: sinon.stub(),
+ };
+ const action = {
+ type: 'GENERIC_ACTION',
+ };
+ const EVENTS_ENDPOINT = 'glob:*/api/v1/async_event/*';
+ const CACHED_DATA_ENDPOINT = 'glob:*/api/v1/chart/data/*';
+ const config = {
+ GLOBAL_ASYNC_QUERIES_TRANSPORT: 'polling',
+ GLOBAL_ASYNC_QUERIES_POLLING_DELAY: 500,
+ };
+ let featureEnabledStub: any;
+
+ function setup() {
+ const getPendingComponents = sinon.stub();
+ const successAction = sinon.spy();
+ const errorAction = sinon.spy();
+ const testCallback = sinon.stub();
+ const testCallbackPromise = sinon.stub();
+ testCallbackPromise.returns(
+ new Promise(resolve => {
+ testCallback.callsFake(resolve);
+ }),
+ );
+
+ return {
+ getPendingComponents,
+ successAction,
+ errorAction,
+ testCallback,
+ testCallbackPromise,
+ };
+ }
+
+ beforeEach(() => {
+ fetchMock.get(EVENTS_ENDPOINT, {
+ status: 200,
+ body: { result: [] },
+ });
+ fetchMock.get(CACHED_DATA_ENDPOINT, {
+ status: 200,
+ body: { result: { some: 'data' } },
+ });
+ featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled');
+ featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(true);
+ });
+ afterEach(() => {
+ fetchMock.reset();
+ next.resetHistory();
+ featureEnabledStub.restore();
+ });
+ afterAll(fetchMock.reset);
+
+ it('should initialize and call next', () => {
+ const { getPendingComponents, successAction, errorAction } = setup();
+ getPendingComponents.returns([]);
+ const asyncEventMiddleware = initAsyncEvents({
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ });
+ asyncEventMiddleware(mockStore)(next)(action);
+ expect(next.callCount).toBe(1);
+ });
+
+ it('should fetch events when there are pending components', () => {
+ const {
+ getPendingComponents,
+ successAction,
+ errorAction,
+ testCallback,
+ testCallbackPromise,
+ } = setup();
+ getPendingComponents.returns(Object.values(state.charts));
+ const asyncEventMiddleware = initAsyncEvents({
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ processEventsCallback: testCallback,
+ });
+
+ asyncEventMiddleware(mockStore)(next)(action);
+
+ return testCallbackPromise().then(() => {
+ expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
+ });
+ });
+
+ it('should fetch cached when there are successful events', () => {
+ const {
+ getPendingComponents,
+ successAction,
+ errorAction,
+ testCallback,
+ testCallbackPromise,
+ } = setup();
+ fetchMock.reset();
+ fetchMock.get(EVENTS_ENDPOINT, {
+ status: 200,
+ body: { result: events },
+ });
+ fetchMock.get(CACHED_DATA_ENDPOINT, {
+ status: 200,
+ body: { result: { some: 'data' } },
+ });
+ getPendingComponents.returns(Object.values(state.charts));
+ const asyncEventMiddleware = initAsyncEvents({
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ processEventsCallback: testCallback,
+ });
+
+ asyncEventMiddleware(mockStore)(next)(action);
+
+ return testCallbackPromise().then(() => {
+ expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
+ expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2);
+ expect(successAction.callCount).toBe(2);
+ });
+ });
+
+ it('should call errorAction for cache fetch error responses', () => {
+ const {
+ getPendingComponents,
+ successAction,
+ errorAction,
+ testCallback,
+ testCallbackPromise,
+ } = setup();
+ fetchMock.reset();
+ fetchMock.get(EVENTS_ENDPOINT, {
+ status: 200,
+ body: { result: events },
+ });
+ fetchMock.get(CACHED_DATA_ENDPOINT, {
+ status: 400,
+ body: { errors: ['error'] },
+ });
+ getPendingComponents.returns(Object.values(state.charts));
+ const asyncEventMiddleware = initAsyncEvents({
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ processEventsCallback: testCallback,
+ });
+
+ asyncEventMiddleware(mockStore)(next)(action);
+
+ return testCallbackPromise().then(() => {
+ expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
+ expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2);
+ expect(errorAction.callCount).toBe(2);
+ });
+ });
+
+ it('should handle event fetching error responses', () => {
+ const {
+ getPendingComponents,
+ successAction,
+ errorAction,
+ testCallback,
+ testCallbackPromise,
+ } = setup();
+ fetchMock.reset();
+ fetchMock.get(EVENTS_ENDPOINT, {
+ status: 400,
+ body: { message: 'error' },
+ });
+ getPendingComponents.returns(Object.values(state.charts));
+ const asyncEventMiddleware = initAsyncEvents({
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ processEventsCallback: testCallback,
+ });
+
+ asyncEventMiddleware(mockStore)(next)(action);
+
+ return testCallbackPromise().then(() => {
+ expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
+ });
+ });
+
+ it('should not fetch events when async queries are disabled', () => {
+ featureEnabledStub.restore();
+ featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled');
+ featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(false);
+ const { getPendingComponents, successAction, errorAction } = setup();
+ getPendingComponents.returns(Object.values(state.charts));
+ const asyncEventMiddleware = initAsyncEvents({
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ });
+
+ asyncEventMiddleware(mockStore)(next)(action);
+ expect(getPendingComponents.called).toBe(false);
+ });
+});
diff --git a/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts b/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts
index 8519b71206..8e89fec284 100644
--- a/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts
+++ b/superset-frontend/spec/javascripts/utils/getClientErrorObject_spec.ts
@@ -17,7 +17,7 @@
* under the License.
*/
import { ErrorTypeEnum } from 'src/components/ErrorMessage/types';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
describe('getClientErrorObject()', () => {
it('Returns a Promise', () => {
diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js
index 7c1bce3e1f..ebfca9116f 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.js
@@ -30,7 +30,7 @@ import {
addSuccessToast as addSuccessToastAction,
addWarningToast as addWarningToastAction,
} from '../../messageToasts/actions/index';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import COMMON_ERR_MESSAGES from '../../utils/errorMessages';
export const RESET_STATE = 'RESET_STATE';
diff --git a/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx b/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx
index dc9fb4267f..c094d08775 100644
--- a/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx
+++ b/superset-frontend/src/SqlLab/components/ShareSqlLabQuery.jsx
@@ -25,7 +25,7 @@ import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
import Button from 'src/components/Button';
import CopyToClipboard from '../../components/CopyToClipboard';
import { storeQuery } from '../../utils/common';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import withToasts from '../../messageToasts/enhancers/withToasts';
const propTypes = {
diff --git a/superset-frontend/src/chart/chartAction.js b/superset-frontend/src/chart/chartAction.js
index f6329a80a5..0198b33a3b 100644
--- a/superset-frontend/src/chart/chartAction.js
+++ b/superset-frontend/src/chart/chartAction.js
@@ -38,7 +38,7 @@ import {
import { addDangerToast } from '../messageToasts/actions';
import { logEvent } from '../logger/actions';
import { Logger, LOG_ACTIONS_LOAD_CHART } from '../logger/LogUtils';
-import getClientErrorObject from '../utils/getClientErrorObject';
+import { getClientErrorObject } from '../utils/getClientErrorObject';
import { allowCrossDomain as domainShardingEnabled } from '../utils/hostNamesConfig';
export const CHART_UPDATE_STARTED = 'CHART_UPDATE_STARTED';
@@ -66,6 +66,11 @@ export function chartUpdateFailed(queryResponse, key) {
return { type: CHART_UPDATE_FAILED, queryResponse, key };
}
+export const CHART_UPDATE_QUEUED = 'CHART_UPDATE_QUEUED';
+export function chartUpdateQueued(asyncJobMeta, key) {
+ return { type: CHART_UPDATE_QUEUED, asyncJobMeta, key };
+}
+
export const CHART_RENDERING_FAILED = 'CHART_RENDERING_FAILED';
export function chartRenderingFailed(error, key, stackTrace) {
return { type: CHART_RENDERING_FAILED, error, key, stackTrace };
@@ -356,6 +361,12 @@ export function exploreJSON(
const chartDataRequestCaught = chartDataRequest
.then(response => {
+ if (isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES)) {
+ // deal with getChartDataRequest transforming the response data
+ const result = 'result' in response ? response.result[0] : response;
+ return dispatch(chartUpdateQueued(result, key));
+ }
+
      // new API returns an object with an array of results
// problem: response holds a list of results, when before we were just getting one result.
// How to make the entire app compatible with multiple results?
diff --git a/superset-frontend/src/chart/chartReducer.js b/superset-frontend/src/chart/chartReducer.js
index b3e72124f9..fc28e99d2c 100644
--- a/superset-frontend/src/chart/chartReducer.js
+++ b/superset-frontend/src/chart/chartReducer.js
@@ -71,6 +71,14 @@ export default function chartReducer(charts = {}, action) {
chartUpdateEndTime: now(),
};
},
+ [actions.CHART_UPDATE_QUEUED](state) {
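+      // keep the async job id so the polling middleware can match incoming events to this chart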
+ return {
+ ...state,
+ asyncJobId: action.asyncJobMeta.job_id,
+ chartStatus: 'loading',
+ chartUpdateEndTime: now(),
+ };
+ },
[actions.CHART_RENDERING_SUCCEEDED](state) {
return { ...state, chartStatus: 'rendered', chartUpdateEndTime: now() };
},
diff --git a/superset-frontend/src/components/AsyncSelect.jsx b/superset-frontend/src/components/AsyncSelect.jsx
index fc9c5eeb59..93bbb34087 100644
--- a/superset-frontend/src/components/AsyncSelect.jsx
+++ b/superset-frontend/src/components/AsyncSelect.jsx
@@ -21,7 +21,7 @@ import PropTypes from 'prop-types';
// TODO: refactor this with `import { AsyncSelect } from src/components/Select`
import { Select } from 'src/components/Select';
import { t, SupersetClient } from '@superset-ui/core';
-import getClientErrorObject from '../utils/getClientErrorObject';
+import { getClientErrorObject } from '../utils/getClientErrorObject';
const propTypes = {
dataEndpoint: PropTypes.string.isRequired,
diff --git a/superset-frontend/src/dashboard/actions/dashboardState.js b/superset-frontend/src/dashboard/actions/dashboardState.js
index a499fa9a8e..fc31a8501a 100644
--- a/superset-frontend/src/dashboard/actions/dashboardState.js
+++ b/superset-frontend/src/dashboard/actions/dashboardState.js
@@ -29,7 +29,7 @@ import {
updateDirectPathToFilter,
} from './dashboardFilters';
import { applyDefaultFormData } from '../../explore/store';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import { SAVE_TYPE_OVERWRITE } from '../util/constants';
import {
addSuccessToast,
diff --git a/superset-frontend/src/dashboard/actions/datasources.js b/superset-frontend/src/dashboard/actions/datasources.js
index 40cba8559a..4277edc661 100644
--- a/superset-frontend/src/dashboard/actions/datasources.js
+++ b/superset-frontend/src/dashboard/actions/datasources.js
@@ -17,7 +17,7 @@
* under the License.
*/
import { SupersetClient } from '@superset-ui/core';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
export const SET_DATASOURCE = 'SET_DATASOURCE';
export function setDatasource(datasource, key) {
diff --git a/superset-frontend/src/dashboard/actions/sliceEntities.js b/superset-frontend/src/dashboard/actions/sliceEntities.js
index 01512ddb14..69472d9155 100644
--- a/superset-frontend/src/dashboard/actions/sliceEntities.js
+++ b/superset-frontend/src/dashboard/actions/sliceEntities.js
@@ -22,7 +22,7 @@ import rison from 'rison';
import { addDangerToast } from 'src/messageToasts/actions';
import { getDatasourceParameter } from 'src/modules/utils';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
export const SET_ALL_SLICES = 'SET_ALL_SLICES';
export function setAllSlices(slices) {
diff --git a/superset-frontend/src/dashboard/components/PropertiesModal.jsx b/superset-frontend/src/dashboard/components/PropertiesModal.jsx
index 019118591d..dd17dcb7d5 100644
--- a/superset-frontend/src/dashboard/components/PropertiesModal.jsx
+++ b/superset-frontend/src/dashboard/components/PropertiesModal.jsx
@@ -35,7 +35,7 @@ import FormLabel from 'src/components/FormLabel';
import { JsonEditor } from 'src/components/AsyncAceEditor';
import ColorSchemeControlWrapper from 'src/dashboard/components/ColorSchemeControlWrapper';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import withToasts from '../../messageToasts/enhancers/withToasts';
import '../stylesheets/buttons.less';
diff --git a/superset-frontend/src/dashboard/index.jsx b/superset-frontend/src/dashboard/index.jsx
index 3937a357df..9fe82346c3 100644
--- a/superset-frontend/src/dashboard/index.jsx
+++ b/superset-frontend/src/dashboard/index.jsx
@@ -24,7 +24,9 @@ import { initFeatureFlags } from 'src/featureFlags';
import { initEnhancer } from '../reduxUtils';
import getInitialState from './reducers/getInitialState';
import rootReducer from './reducers/index';
+import initAsyncEvents from '../middleware/asyncEvent';
import logger from '../middleware/loggerMiddleware';
+import * as actions from '../chart/chartAction';
import App from './App';
@@ -33,10 +35,23 @@ const bootstrapData = JSON.parse(appContainer.getAttribute('data-bootstrap'));
initFeatureFlags(bootstrapData.common.feature_flags);
const initState = getInitialState(bootstrapData);
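+// Treat charts still in 'loading' state as pending async components and
+// resolve them through the standard chart success/failure actions.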
+const asyncEventMiddleware = initAsyncEvents({
+ config: bootstrapData.common.conf,
+ getPendingComponents: ({ charts }) =>
+ Object.values(charts).filter(c => c.chartStatus === 'loading'),
+ successAction: (componentId, componentData) =>
+ actions.chartUpdateSucceeded(componentData, componentId),
+ errorAction: (componentId, response) =>
+ actions.chartUpdateFailed(response, componentId),
+});
+
const store = createStore(
rootReducer,
initState,
- compose(applyMiddleware(thunk, logger), initEnhancer(false)),
+ compose(
+ applyMiddleware(thunk, logger, asyncEventMiddleware),
+ initEnhancer(false),
+ ),
);
ReactDOM.render(<App />, document.getElementById('app'));
diff --git a/superset-frontend/src/datasource/ChangeDatasourceModal.tsx b/superset-frontend/src/datasource/ChangeDatasourceModal.tsx
index 4f1597be90..81b9b8cb8c 100644
--- a/superset-frontend/src/datasource/ChangeDatasourceModal.tsx
+++ b/superset-frontend/src/datasource/ChangeDatasourceModal.tsx
@@ -27,7 +27,7 @@ import { Alert, FormControl, FormControlProps } from 'react-bootstrap';
import { SupersetClient, t } from '@superset-ui/core';
import TableView from 'src/components/TableView';
import Modal from 'src/common/components/Modal';
-import getClientErrorObject from '../utils/getClientErrorObject';
+import { getClientErrorObject } from '../utils/getClientErrorObject';
import Loading from '../components/Loading';
import withToasts from '../messageToasts/enhancers/withToasts';
diff --git a/superset-frontend/src/datasource/DatasourceEditor.jsx b/superset-frontend/src/datasource/DatasourceEditor.jsx
index aebeaeb808..dc2d44444e 100644
--- a/superset-frontend/src/datasource/DatasourceEditor.jsx
+++ b/superset-frontend/src/datasource/DatasourceEditor.jsx
@@ -33,7 +33,7 @@ import Loading from 'src/components/Loading';
import TableSelector from 'src/components/TableSelector';
import EditableTitle from 'src/components/EditableTitle';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import CheckboxControl from 'src/explore/components/controls/CheckboxControl';
import TextControl from 'src/explore/components/controls/TextControl';
diff --git a/superset-frontend/src/datasource/DatasourceModal.tsx b/superset-frontend/src/datasource/DatasourceModal.tsx
index 0a1b1c8e2b..daf47c2594 100644
--- a/superset-frontend/src/datasource/DatasourceModal.tsx
+++ b/superset-frontend/src/datasource/DatasourceModal.tsx
@@ -25,7 +25,7 @@ import Modal from 'src/common/components/Modal';
import AsyncEsmComponent from 'src/components/AsyncEsmComponent';
import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import withToasts from 'src/messageToasts/enhancers/withToasts';
const DatasourceEditor = AsyncEsmComponent(() => import('./DatasourceEditor'));
diff --git a/superset-frontend/src/explore/components/DataTablesPane.tsx b/superset-frontend/src/explore/components/DataTablesPane.tsx
index b1bc9081db..474ef94f83 100644
--- a/superset-frontend/src/explore/components/DataTablesPane.tsx
+++ b/superset-frontend/src/explore/components/DataTablesPane.tsx
@@ -23,7 +23,7 @@ import Tabs from 'src/common/components/Tabs';
import Loading from 'src/components/Loading';
import TableView, { EmptyWrapperType } from 'src/components/TableView';
import { getChartDataRequest } from 'src/chart/chartAction';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import {
CopyToClipboardButton,
FilterInput,
diff --git a/superset-frontend/src/explore/components/DisplayQueryButton.jsx b/superset-frontend/src/explore/components/DisplayQueryButton.jsx
index 1482156098..dbb5d8441b 100644
--- a/superset-frontend/src/explore/components/DisplayQueryButton.jsx
+++ b/superset-frontend/src/explore/components/DisplayQueryButton.jsx
@@ -30,7 +30,7 @@ import { DropdownButton } from 'react-bootstrap';
import { styled, t } from '@superset-ui/core';
import { Menu } from 'src/common/components';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import CopyToClipboard from '../../components/CopyToClipboard';
import { getChartDataRequest } from '../../chart/chartAction';
import downloadAsImage from '../../utils/downloadAsImage';
diff --git a/superset-frontend/src/explore/components/PropertiesModal.tsx b/superset-frontend/src/explore/components/PropertiesModal.tsx
index 1281bc382f..240d2b6840 100644
--- a/superset-frontend/src/explore/components/PropertiesModal.tsx
+++ b/superset-frontend/src/explore/components/PropertiesModal.tsx
@@ -32,7 +32,7 @@ import rison from 'rison';
import { t, SupersetClient } from '@superset-ui/core';
import Chart, { Slice } from 'src/types/Chart';
import FormLabel from 'src/components/FormLabel';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
type PropertiesModalProps = {
slice: Slice;
diff --git a/superset-frontend/src/explore/index.jsx b/superset-frontend/src/explore/index.jsx
index 25a704757a..83e4bc63dc 100644
--- a/superset-frontend/src/explore/index.jsx
+++ b/superset-frontend/src/explore/index.jsx
@@ -25,6 +25,8 @@ import { initFeatureFlags } from '../featureFlags';
import { initEnhancer } from '../reduxUtils';
import getInitialState from './reducers/getInitialState';
import rootReducer from './reducers/index';
+import initAsyncEvents from '../middleware/asyncEvent';
+import * as actions from '../chart/chartAction';
import App from './App';
@@ -35,10 +37,23 @@ const bootstrapData = JSON.parse(
initFeatureFlags(bootstrapData.common.feature_flags);
const initState = getInitialState(bootstrapData);
+const asyncEventMiddleware = initAsyncEvents({
+ config: bootstrapData.common.conf,
+ getPendingComponents: ({ charts }) =>
+ Object.values(charts).filter(c => c.chartStatus === 'loading'),
+ successAction: (componentId, componentData) =>
+ actions.chartUpdateSucceeded(componentData, componentId),
+ errorAction: (componentId, response) =>
+ actions.chartUpdateFailed(response, componentId),
+});
+
const store = createStore(
rootReducer,
initState,
- compose(applyMiddleware(thunk, logger), initEnhancer(false)),
+ compose(
+ applyMiddleware(thunk, logger, asyncEventMiddleware),
+ initEnhancer(false),
+ ),
);
ReactDOM.render(<App />, document.getElementById('app'));
diff --git a/superset-frontend/src/featureFlags.ts b/superset-frontend/src/featureFlags.ts
index 1e024a5165..93b909ddcb 100644
--- a/superset-frontend/src/featureFlags.ts
+++ b/superset-frontend/src/featureFlags.ts
@@ -34,6 +34,7 @@ export enum FeatureFlag {
DISPLAY_MARKDOWN_HTML = 'DISPLAY_MARKDOWN_HTML',
ESCAPE_MARKDOWN_HTML = 'ESCAPE_MARKDOWN_HTML',
VERSIONED_EXPORT = 'VERSIONED_EXPORT',
+ GLOBAL_ASYNC_QUERIES = 'GLOBAL_ASYNC_QUERIES',
}
export type FeatureFlagMap = {
diff --git a/superset-frontend/src/middleware/asyncEvent.ts b/superset-frontend/src/middleware/asyncEvent.ts
new file mode 100644
index 0000000000..637bb1b38d
--- /dev/null
+++ b/superset-frontend/src/middleware/asyncEvent.ts
@@ -0,0 +1,196 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import { Middleware, MiddlewareAPI, Dispatch } from 'redux';
+import { makeApi, SupersetClient } from '@superset-ui/core';
+import { SupersetError } from 'src/components/ErrorMessage/types';
+import { isFeatureEnabled, FeatureFlag } from '../featureFlags';
+import {
+ getClientErrorObject,
+ parseErrorJson,
+} from '../utils/getClientErrorObject';
+
+export type AsyncEvent = {
+ id: string;
+ channel_id: string;
+ job_id: string;
+ user_id: string;
+ status: string;
+ errors: SupersetError[];
+ result_url: string;
+};
+
+type AsyncEventOptions = {
+ config: {
+ GLOBAL_ASYNC_QUERIES_TRANSPORT: string;
+ GLOBAL_ASYNC_QUERIES_POLLING_DELAY: number;
+ };
+ getPendingComponents: (state: any) => any[];
+ successAction: (componentId: number, componentData: any) => { type: string };
+ errorAction: (componentId: number, response: any) => { type: string };
+ processEventsCallback?: (events: AsyncEvent[]) => void; // this is currently used only for tests
+};
+
+type CachedDataResponse = {
+ componentId: number;
+ status: string;
+ data: any;
+};
+
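+// Factory for a Redux middleware that polls the async events API and resolves
+// pending components (e.g. loading charts) via the supplied action creators.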
+const initAsyncEvents = (options: AsyncEventOptions) => {
+ // TODO: implement websocket support
+ const TRANSPORT_POLLING = 'polling';
+ const {
+ config,
+ getPendingComponents,
+ successAction,
+ errorAction,
+ processEventsCallback,
+ } = options;
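+  // transport and polling interval come from server bootstrap config, with safe defaults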
+ const transport = config.GLOBAL_ASYNC_QUERIES_TRANSPORT || TRANSPORT_POLLING;
+ const polling_delay = config.GLOBAL_ASYNC_QUERIES_POLLING_DELAY || 500;
+
+ const middleware: Middleware = (store: MiddlewareAPI) => (next: Dispatch) => {
+ const JOB_STATUS = {
+ PENDING: 'pending',
+ RUNNING: 'running',
+ ERROR: 'error',
+ DONE: 'done',
+ };
+ const LOCALSTORAGE_KEY = 'last_async_event_id';
+ const POLLING_URL = '/api/v1/async_event/';
+ let lastReceivedEventId: string | null;
+
+ try {
+ lastReceivedEventId = localStorage.getItem(LOCALSTORAGE_KEY);
+ } catch (err) {
+      console.warn('Failed to fetch last event ID from localStorage', err);
+ }
+
+ const fetchEvents = makeApi<
+ { last_id?: string | null },
+ { result: AsyncEvent[] }
+ >({
+ method: 'GET',
+ endpoint: POLLING_URL,
+ });
+
+ const fetchCachedData = async (
+ asyncEvent: AsyncEvent,
+ componentId: number,
+  ): Promise<CachedDataResponse> => {
+ let status = 'success';
+ let data;
+ try {
+ const { json } = await SupersetClient.get({
+ endpoint: asyncEvent.result_url,
+ });
+ data = 'result' in json ? json.result[0] : json;
+ } catch (response) {
+ status = 'error';
+ data = await getClientErrorObject(response);
+ }
+
+ return { componentId, status, data };
+ };
+
+ const setLastId = (asyncEvent: AsyncEvent) => {
+ lastReceivedEventId = asyncEvent.id;
+ try {
+ localStorage.setItem(LOCALSTORAGE_KEY, lastReceivedEventId as string);
+ } catch (err) {
+ console.warn('Error saving event ID to localStorage', err);
+ }
+ };
+
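+    // Polling loop: fetch new events for pending components, dispatch
+    // success/error actions for completed jobs, then reschedule itself.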
+ const processEvents = async () => {
+ const state = store.getState();
+ const queuedComponents = getPendingComponents(state);
+ const eventArgs = lastReceivedEventId
+ ? { last_id: lastReceivedEventId }
+ : {};
+      let events: AsyncEvent[] = [];
+ if (queuedComponents && queuedComponents.length) {
+ try {
+          ({ result: events } = await fetchEvents(eventArgs));
+ if (events && events.length) {
+ const componentsByJobId = queuedComponents.reduce((acc, item) => {
+ acc[item.asyncJobId] = item;
+ return acc;
+ }, {});
+            const fetchDataEvents: Promise<CachedDataResponse>[] = [];
+ events.forEach((asyncEvent: AsyncEvent) => {
+ const component = componentsByJobId[asyncEvent.job_id];
+ if (!component) {
+ console.warn(
+ 'component not found for job_id',
+ asyncEvent.job_id,
+ );
+ return setLastId(asyncEvent);
+ }
+ const componentId = component.id;
+ switch (asyncEvent.status) {
+ case JOB_STATUS.DONE:
+ fetchDataEvents.push(
+ fetchCachedData(asyncEvent, componentId),
+ );
+ break;
+ case JOB_STATUS.ERROR:
+ store.dispatch(
+ errorAction(componentId, parseErrorJson(asyncEvent)),
+ );
+ break;
+ default:
+ console.warn('received event with status', asyncEvent.status);
+ }
+
+ return setLastId(asyncEvent);
+ });
+
+ const fetchResults = await Promise.all(fetchDataEvents);
+ fetchResults.forEach(result => {
+ if (result.status === 'success') {
+ store.dispatch(successAction(result.componentId, result.data));
+ } else {
+ store.dispatch(errorAction(result.componentId, result.data));
+ }
+ });
+ }
+ } catch (err) {
+ console.warn(err);
+ }
+ }
+
+ if (processEventsCallback) processEventsCallback(events);
+
+ return setTimeout(processEvents, polling_delay);
+ };
+
+ if (
+ isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES) &&
+ transport === TRANSPORT_POLLING
+ )
+ processEvents();
+
+ return action => next(action);
+ };
+
+ return middleware;
+};
+
+export default initAsyncEvents;
diff --git a/superset-frontend/src/setup/setupApp.ts b/superset-frontend/src/setup/setupApp.ts
index 1eef937bc7..740c97b355 100644
--- a/superset-frontend/src/setup/setupApp.ts
+++ b/superset-frontend/src/setup/setupApp.ts
@@ -19,7 +19,8 @@
/* eslint global-require: 0 */
import $ from 'jquery';
import { SupersetClient } from '@superset-ui/core';
-import getClientErrorObject, {
+import {
+ getClientErrorObject,
ClientErrorObject,
} from '../utils/getClientErrorObject';
import setupErrorMessages from './setupErrorMessages';
diff --git a/superset-frontend/src/utils/common.js b/superset-frontend/src/utils/common.js
index 033382294e..2753c0205c 100644
--- a/superset-frontend/src/utils/common.js
+++ b/superset-frontend/src/utils/common.js
@@ -21,7 +21,7 @@ import {
getTimeFormatter,
TimeFormats,
} from '@superset-ui/core';
-import getClientErrorObject from './getClientErrorObject';
+import { getClientErrorObject } from './getClientErrorObject';
// ATTENTION: If you change any constants, make sure to also change constants.py
diff --git a/superset-frontend/src/utils/getClientErrorObject.ts b/superset-frontend/src/utils/getClientErrorObject.ts
index 274b8b4d93..269126169f 100644
--- a/superset-frontend/src/utils/getClientErrorObject.ts
+++ b/superset-frontend/src/utils/getClientErrorObject.ts
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-import { SupersetClientResponse, t } from '@superset-ui/core';
+import { JsonObject, SupersetClientResponse, t } from '@superset-ui/core';
import {
SupersetError,
ErrorTypeEnum,
@@ -36,7 +36,33 @@ export type ClientErrorObject = {
stacktrace?: string;
} & Partial<SupersetClientResponse>;
-export default function getClientErrorObject(
+export function parseErrorJson(responseObject: JsonObject): ClientErrorObject {
+ let error = { ...responseObject };
+ // Backwards compatibility for old error renderers with the new error object
+ if (error.errors && error.errors.length > 0) {
+ error.error = error.description = error.errors[0].message;
+ error.link = error.errors[0]?.extra?.link;
+ }
+
+ if (error.stack) {
+ error = {
+ ...error,
+ error:
+ t('Unexpected error: ') +
+ (error.description || t('(no description, click to see stack trace)')),
+ stacktrace: error.stack,
+ };
+ } else if (error.responseText && error.responseText.indexOf('CSRF') >= 0) {
+ error = {
+ ...error,
+ error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT),
+ };
+ }
+
+ return { ...error, error: error.error }; // explicit ClientErrorObject
+}
+
+export function getClientErrorObject(
response: SupersetClientResponse | (Response & { timeout: number }) | string,
): Promise<ClientErrorObject> {
// takes a SupersetClientResponse as input, attempts to read response as Json if possible,
@@ -54,33 +80,8 @@ export default function getClientErrorObject(
.clone()
.json()
.then(errorJson => {
- let error = { ...responseObject, ...errorJson };
-
- // Backwards compatibility for old error renderers with the new error object
- if (error.errors && error.errors.length > 0) {
- error.error = error.description = error.errors[0].message;
- error.link = error.errors[0]?.extra?.link;
- }
-
- if (error.stack) {
- error = {
- ...error,
- error:
- t('Unexpected error: ') +
- (error.description ||
- t('(no description, click to see stack trace)')),
- stacktrace: error.stack,
- };
- } else if (
- error.responseText &&
- error.responseText.indexOf('CSRF') >= 0
- ) {
- error = {
- ...error,
- error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT),
- };
- }
- resolve(error);
+ const error = { ...responseObject, ...errorJson };
+ resolve(parseErrorJson(error));
})
.catch(() => {
// fall back to reading as text
diff --git a/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx b/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx
index dba16e3465..d1b7ef5355 100644
--- a/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx
+++ b/superset-frontend/src/views/CRUD/annotation/AnnotationList.tsx
@@ -29,7 +29,7 @@ import ConfirmStatusChange from 'src/components/ConfirmStatusChange';
import DeleteModal from 'src/components/DeleteModal';
import ListView, { ListViewProps } from 'src/components/ListView';
import SubMenu, { SubMenuProps } from 'src/components/Menu/SubMenu';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import withToasts from 'src/messageToasts/enhancers/withToasts';
import { IconName } from 'src/components/Icon';
import { useListViewResource } from 'src/views/CRUD/hooks';
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx
index fa8a3625ab..b8f061c1b9 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx
@@ -21,7 +21,7 @@ import { styled, t, SupersetClient } from '@superset-ui/core';
import InfoTooltip from 'src/common/components/InfoTooltip';
import { useSingleViewResource } from 'src/views/CRUD/hooks';
import withToasts from 'src/messageToasts/enhancers/withToasts';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import Icon from 'src/components/Icon';
import Modal from 'src/common/components/Modal';
import Tabs from 'src/common/components/Tabs';
diff --git a/superset-frontend/src/views/CRUD/hooks.ts b/superset-frontend/src/views/CRUD/hooks.ts
index e209815aa3..dedcadd87a 100644
--- a/superset-frontend/src/views/CRUD/hooks.ts
+++ b/superset-frontend/src/views/CRUD/hooks.ts
@@ -25,7 +25,7 @@ import { FetchDataConfig } from 'src/components/ListView';
import { FilterValue } from 'src/components/ListView/types';
import Chart, { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import { FavoriteStatus } from './types';
interface ListViewResourceState {
diff --git a/superset-frontend/src/views/CRUD/utils.tsx b/superset-frontend/src/views/CRUD/utils.tsx
index e22310dd4d..ea79121277 100644
--- a/superset-frontend/src/views/CRUD/utils.tsx
+++ b/superset-frontend/src/views/CRUD/utils.tsx
@@ -25,7 +25,7 @@ import {
} from '@superset-ui/core';
import Chart from 'src/types/Chart';
import rison from 'rison';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import { FetchDataConfig } from 'src/components/ListView';
import { Dashboard } from './types';
diff --git a/superset/app.py b/superset/app.py
index bdba4cc7f5..d8f2ad252f 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -30,6 +30,7 @@ from superset.extensions import (
_event_logger,
APP_DIR,
appbuilder,
+ async_query_manager,
cache_manager,
celery_app,
csrf,
@@ -127,6 +128,7 @@ class SupersetAppInitializer:
# pylint: disable=too-many-branches
from superset.annotation_layers.api import AnnotationLayerRestApi
from superset.annotation_layers.annotations.api import AnnotationRestApi
+ from superset.async_events.api import AsyncEventsRestApi
from superset.cachekeys.api import CacheRestApi
from superset.charts.api import ChartRestApi
from superset.connectors.druid.views import (
@@ -201,6 +203,7 @@ class SupersetAppInitializer:
#
appbuilder.add_api(AnnotationRestApi)
appbuilder.add_api(AnnotationLayerRestApi)
+ appbuilder.add_api(AsyncEventsRestApi)
appbuilder.add_api(CacheRestApi)
appbuilder.add_api(ChartRestApi)
appbuilder.add_api(CssTemplateRestApi)
@@ -498,6 +501,7 @@ class SupersetAppInitializer:
self.configure_url_map_converters()
self.configure_data_sources()
self.configure_auth_provider()
+ self.configure_async_queries()
# Hook that provides administrators a handle on the Flask APP
# after initialization
@@ -648,6 +652,10 @@ class SupersetAppInitializer:
for ex in csrf_exempt_list:
csrf.exempt(ex)
+ def configure_async_queries(self) -> None:
+ if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"):
+ async_query_manager.init_app(self.flask_app)
+
def register_blueprints(self) -> None:
for bp in self.config["BLUEPRINTS"]:
try:
diff --git a/superset/async_events/__init__.py b/superset/async_events/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/superset/async_events/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/async_events/api.py b/superset/async_events/api.py
new file mode 100644
index 0000000000..61b85ac6f9
--- /dev/null
+++ b/superset/async_events/api.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+
+from flask import request, Response
+from flask_appbuilder import expose
+from flask_appbuilder.api import BaseApi, safe
+from flask_appbuilder.security.decorators import permission_name, protect
+
+from superset.extensions import async_query_manager, event_logger
+from superset.utils.async_query_manager import AsyncQueryTokenException
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncEventsRestApi(BaseApi):
+ resource_name = "async_event"
+ allow_browser_login = True
+ include_route_methods = {
+ "events",
+ }
+
+ @expose("/", methods=["GET"])
+ @event_logger.log_this
+ @protect()
+ @safe
+ @permission_name("list")
+ def events(self) -> Response:
+ """
+        Reads from the Redis async events stream, using the user's JWT token and
+        an optional query param for the last event received.
+ ---
+ get:
+ description: >-
+            Reads from the Redis events stream, using the user's JWT token and
+            an optional query param for the last event received.
+ parameters:
+ - in: query
+ name: last_id
+ description: Last ID received by the client
+ schema:
+ type: string
+ responses:
+ 200:
+ description: Async event results
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ result:
+ type: array
+ items:
+ type: object
+ properties:
+ id:
+ type: string
+ channel_id:
+ type: string
+ job_id:
+ type: string
+ user_id:
+ type: integer
+ status:
+ type: string
+ msg:
+ type: string
+ cache_key:
+ type: string
+ 401:
+ $ref: '#/components/responses/401'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ try:
+ async_channel_id = async_query_manager.parse_jwt_from_request(request)[
+ "channel"
+ ]
+ last_event_id = request.args.get("last_id")
+ events = async_query_manager.read_events(async_channel_id, last_event_id)
+
+ except AsyncQueryTokenException:
+ return self.response_401()
+
+ return self.response(200, result=events)
diff --git a/superset/charts/api.py b/superset/charts/api.py
index a3a2737aae..50245be3a0 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -33,10 +33,13 @@ from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset.charts.commands.bulk_delete import BulkDeleteChartCommand
from superset.charts.commands.create import CreateChartCommand
+from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.delete import DeleteChartCommand
from superset.charts.commands.exceptions import (
ChartBulkDeleteFailedError,
ChartCreateFailedError,
+ ChartDataCacheLoadError,
+ ChartDataQueryFailedError,
ChartDeleteFailedError,
ChartForbiddenError,
ChartInvalidError,
@@ -50,7 +53,6 @@ from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
from superset.charts.schemas import (
CHART_SCHEMAS,
- ChartDataQueryContextSchema,
ChartPostSchema,
ChartPutSchema,
get_delete_ids_schema,
@@ -67,7 +69,12 @@ from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
-from superset.utils.core import ChartDataResultFormat, json_int_dttm_ser
+from superset.utils.async_query_manager import AsyncQueryTokenException
+from superset.utils.core import (
+ ChartDataResultFormat,
+ ChartDataResultType,
+ json_int_dttm_ser,
+)
from superset.utils.screenshots import ChartScreenshot
from superset.utils.urls import get_url_path
from superset.views.base_api import (
@@ -93,6 +100,8 @@ class ChartRestApi(BaseSupersetModelRestApi):
RouteMethod.RELATED,
"bulk_delete", # not using RouteMethod since locally defined
"data",
+ "data_from_cache",
+ "viz_types",
"favorite_status",
}
class_permission_name = "SliceModelView"
@@ -448,6 +457,39 @@ class ChartRestApi(BaseSupersetModelRestApi):
except ChartBulkDeleteFailedError as ex:
return self.response_422(message=str(ex))
+ def get_data_response(
+ self, command: ChartDataCommand, force_cached: bool = False
+ ) -> Response:
+ try:
+ result = command.run(force_cached=force_cached)
+ except ChartDataCacheLoadError as exc:
+ return self.response_422(message=exc.message)
+ except ChartDataQueryFailedError as exc:
+ return self.response_400(message=exc.message)
+
+ result_format = result["query_context"].result_format
+ if result_format == ChartDataResultFormat.CSV:
+ # return the first result
+ data = result["queries"][0]["data"]
+ return CsvResponse(
+ data,
+ status=200,
+ headers=generate_download_headers("csv"),
+ mimetype="application/csv",
+ )
+
+ if result_format == ChartDataResultFormat.JSON:
+ response_data = simplejson.dumps(
+ {"result": result["queries"]},
+ default=json_int_dttm_ser,
+ ignore_nan=True,
+ )
+ resp = make_response(response_data, 200)
+ resp.headers["Content-Type"] = "application/json; charset=utf-8"
+ return resp
+
+ return self.response_400(message=f"Unsupported result_format: {result_format}")
+
@expose("/data", methods=["POST"])
@protect()
@safe
@@ -478,8 +520,16 @@ class ChartRestApi(BaseSupersetModelRestApi):
application/json:
schema:
$ref: "#/components/schemas/ChartDataResponseSchema"
+ 202:
+ description: Async job details
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
400:
$ref: '#/components/responses/400'
+ 401:
+ $ref: '#/components/responses/401'
500:
$ref: '#/components/responses/500'
"""
@@ -490,47 +540,88 @@ class ChartRestApi(BaseSupersetModelRestApi):
json_body = json.loads(request.form["form_data"])
else:
return self.response_400(message="Request is not JSON")
+
try:
- query_context = ChartDataQueryContextSchema().load(json_body)
- except KeyError:
- return self.response_400(message="Request is incorrect")
+ command = ChartDataCommand()
+ query_context = command.set_query_context(json_body)
+ command.validate()
except ValidationError as error:
return self.response_400(
message=_("Request is incorrect: %(error)s", error=error.messages)
)
- try:
- query_context.raise_for_access()
except SupersetSecurityException:
return self.response_401()
- payload = query_context.get_payload()
- for query in payload:
- if query.get("error"):
- return self.response_400(message=f"Error: {query['error']}")
- result_format = query_context.result_format
- response = self.response_400(
- message=f"Unsupported result_format: {result_format}"
- )
+ # TODO: support CSV, SQL query and other non-JSON types
+ if (
+ is_feature_enabled("GLOBAL_ASYNC_QUERIES")
+ and query_context.result_format == ChartDataResultFormat.JSON
+ and query_context.result_type == ChartDataResultType.FULL
+ ):
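+            # Return 202 with job metadata; the client polls the async events
+            # API and then fetches results via /chart/data/<cache_key>.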
- if result_format == ChartDataResultFormat.CSV:
- # return the first result
- result = payload[0]["data"]
- response = CsvResponse(
- result,
- status=200,
- headers=generate_download_headers("csv"),
- mimetype="application/csv",
+ try:
+ command.validate_async_request(request)
+ except AsyncQueryTokenException:
+ return self.response_401()
+
+ result = command.run_async()
+ return self.response(202, **result)
+
+ return self.get_data_response(command)
+
+ @expose("/data/", methods=["GET"])
+ @event_logger.log_this
+ @protect()
+ @safe
+ @statsd_metrics
+ def data_from_cache(self, cache_key: str) -> Response:
+ """
+        Takes a query context cache key and returns the chart data payload
+        for the given query.
+ ---
+ get:
+ description: >-
+            Takes a query context cache key and returns the chart data
+            payload for the given query.
+ parameters:
+ - in: path
+ schema:
+ type: string
+ name: cache_key
+ responses:
+ 200:
+ description: Query result
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/ChartDataResponseSchema"
+ 400:
+ $ref: '#/components/responses/400'
+ 401:
+ $ref: '#/components/responses/401'
+ 404:
+ $ref: '#/components/responses/404'
+ 422:
+ $ref: '#/components/responses/422'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ command = ChartDataCommand()
+ try:
+ cached_data = command.load_query_context_from_cache(cache_key)
+ command.set_query_context(cached_data)
+ command.validate()
+ except ChartDataCacheLoadError:
+ return self.response_404()
+ except ValidationError as error:
+ return self.response_400(
+ message=_("Request is incorrect: %(error)s", error=error.messages)
)
+ except SupersetSecurityException as exc:
+ logger.info(exc)
+ return self.response_401()
- if result_format == ChartDataResultFormat.JSON:
- response_data = simplejson.dumps(
- {"result": payload}, default=json_int_dttm_ser, ignore_nan=True
- )
- resp = make_response(response_data, 200)
- resp.headers["Content-Type"] = "application/json; charset=utf-8"
- response = resp
-
- return response
+ return self.get_data_response(command, True)
@expose("//cache_screenshot/", methods=["GET"])
@protect()
diff --git a/superset/charts/commands/data.py b/superset/charts/commands/data.py
new file mode 100644
index 0000000000..275a723d60
--- /dev/null
+++ b/superset/charts/commands/data.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Any, Dict
+
+from flask import Request
+from marshmallow import ValidationError
+
+from superset import cache
+from superset.charts.commands.exceptions import (
+ ChartDataCacheLoadError,
+ ChartDataQueryFailedError,
+)
+from superset.charts.schemas import ChartDataQueryContextSchema
+from superset.commands.base import BaseCommand
+from superset.common.query_context import QueryContext
+from superset.exceptions import CacheLoadError
+from superset.extensions import async_query_manager
+from superset.tasks.async_queries import load_chart_data_into_cache
+
+logger = logging.getLogger(__name__)
+
+
+class ChartDataCommand(BaseCommand):
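+    """Execute a chart data query context, synchronously or as an async Celery job."""
+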
+ def __init__(self) -> None:
+ self._form_data: Dict[str, Any]
+ self._query_context: QueryContext
+ self._async_channel_id: str
+
+ def run(self, **kwargs: Any) -> Dict[str, Any]:
+ # caching is handled in query_context.get_df_payload
+ # (also evals `force` property)
+ cache_query_context = kwargs.get("cache", False)
+ force_cached = kwargs.get("force_cached", False)
+ try:
+ payload = self._query_context.get_payload(
+ cache_query_context=cache_query_context, force_cached=force_cached
+ )
+ except CacheLoadError as exc:
+ raise ChartDataCacheLoadError(exc.message)
+
+ # TODO: QueryContext should support SIP-40 style errors
+ for query in payload["queries"]:
+ if query.get("error"):
+ raise ChartDataQueryFailedError(f"Error: {query['error']}")
+
+ return_value = {
+ "query_context": self._query_context,
+ "queries": payload["queries"],
+ }
+ if cache_query_context:
+ return_value.update(cache_key=payload["cache_key"])
+
+ return return_value
+
+ def run_async(self) -> Dict[str, Any]:
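+        # Register the job on the user's async channel, then hand the query to
+        # a Celery worker that loads the results into the cache.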
+ job_metadata = async_query_manager.init_job(self._async_channel_id)
+ load_chart_data_into_cache.delay(job_metadata, self._form_data)
+
+ return job_metadata
+
+ def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
+ self._form_data = form_data
+ try:
+ self._query_context = ChartDataQueryContextSchema().load(self._form_data)
+ except KeyError:
+ raise ValidationError("Request is incorrect")
+ except ValidationError as error:
+ raise error
+
+ return self._query_context
+
+ def validate(self) -> None:
+ self._query_context.raise_for_access()
+
+ def validate_async_request(self, request: Request) -> None:
+ jwt_data = async_query_manager.parse_jwt_from_request(request)
+ self._async_channel_id = jwt_data["channel"]
+
+ def load_query_context_from_cache( # pylint: disable=no-self-use
+ self, cache_key: str
+ ) -> Dict[str, Any]:
+ cache_value = cache.get(cache_key)
+ if not cache_value:
+ raise ChartDataCacheLoadError("Cached data not found")
+
+ return cache_value["data"]
diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py
index 51b5ca84f0..d1fa375c02 100644
--- a/superset/charts/commands/exceptions.py
+++ b/superset/charts/commands/exceptions.py
@@ -90,6 +90,14 @@ class ChartBulkDeleteFailedError(DeleteFailedError):
message = _("Charts could not be deleted.")
+class ChartDataQueryFailedError(CommandException):
+ pass
+
+
+class ChartDataCacheLoadError(CommandException):
+ pass
+
+
class ChartBulkDeleteFailedReportsExistError(ChartBulkDeleteFailedError):
message = _("There are associated alerts or reports")
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 347189a13e..5de166c86e 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -1105,6 +1105,18 @@ class ChartDataResponseSchema(Schema):
)
+class ChartDataAsyncResponseSchema(Schema):
+ channel_id = fields.String(
+ description="Unique session async channel ID", allow_none=False,
+ )
+ job_id = fields.String(description="Unique async job ID", allow_none=False,)
+ user_id = fields.String(description="Requesting user ID", allow_none=True,)
+ status = fields.String(description="Status value for async job", allow_none=False,)
+ result_url = fields.String(
+ description="Unique result URL for fetching async query data", allow_none=False,
+ )
+
+
class ChartFavStarResponseResult(Schema):
id = fields.Integer(description="The Chart id")
value = fields.Boolean(description="The FaveStar value")
@@ -1130,6 +1142,7 @@ class ImportV1ChartSchema(Schema):
CHART_SCHEMAS = (
ChartDataQueryContextSchema,
ChartDataResponseSchema,
+ ChartDataAsyncResponseSchema,
# TODO: These should optimally be included in the QueryContext schema as an `anyOf`
    # in ChartDataPostProcessingOperation.options, but since `anyOf` is not supported
    # by Marshmallow<3, this is not currently possible.
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 25113c36ae..a7900cf667 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -17,7 +17,7 @@
import copy
import logging
import math
-from datetime import datetime, timedelta
+from datetime import timedelta
from typing import Any, cast, ClassVar, Dict, List, Optional, Union
import numpy as np
@@ -30,13 +30,17 @@ from superset.charts.dao import ChartDAO
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
-from superset.exceptions import QueryObjectValidationError, SupersetException
+from superset.exceptions import (
+ CacheLoadError,
+ QueryObjectValidationError,
+ SupersetException,
+)
from superset.extensions import cache_manager, security_manager
from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
+from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import DTTM_ALIAS
from superset.views.utils import get_viz
-from superset.viz import set_and_log_cache
config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
@@ -78,6 +82,13 @@ class QueryContext:
self.custom_cache_timeout = custom_cache_timeout
self.result_type = result_type or utils.ChartDataResultType.FULL
self.result_format = result_format or utils.ChartDataResultFormat.JSON
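+ # Capture the constructor args so the query context can be cached and
+ # rehydrated later for async execution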
+ self.cache_values = {
+ "datasource": datasource,
+ "queries": queries,
+ "force": force,
+ "result_type": result_type,
+ "result_format": result_format,
+ }
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
@@ -142,8 +153,11 @@ class QueryContext:
return df.to_dict(orient="records")
- def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
+ def get_single_payload(
+ self, query_obj: QueryObject, **kwargs: Any
+ ) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
+ force_cached = kwargs.get("force_cached", False)
if self.result_type == utils.ChartDataResultType.QUERY:
return {
"query": self.datasource.get_query_str(query_obj.to_dict()),
@@ -159,8 +173,7 @@ class QueryContext:
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in self.datasource.columns]
- payload = self.get_df_payload(query_obj)
-
+ payload = self.get_df_payload(query_obj, force_cached=force_cached)
df = payload["df"]
status = payload["status"]
if status != utils.QueryStatus.FAILED:
@@ -186,9 +199,28 @@ class QueryContext:
return {"data": payload["data"]}
return payload
- def get_payload(self) -> List[Dict[str, Any]]:
- """Get all the payloads from the QueryObjects"""
- return [self.get_single_payload(query_object) for query_object in self.queries]
+ def get_payload(self, **kwargs: Any) -> Dict[str, Any]:
+ cache_query_context = kwargs.get("cache_query_context", False)
+ force_cached = kwargs.get("force_cached", False)
+
+ # Get all the payloads from the QueryObjects
+ query_results = [
+ self.get_single_payload(query_object, force_cached=force_cached)
+ for query_object in self.queries
+ ]
+ return_value = {"queries": query_results}
+
+ if cache_query_context:
+ cache_key = self.cache_key()
+ set_and_log_cache(
+ cache_manager.cache,
+ cache_key,
+ {"data": self.cache_values},
+ self.cache_timeout,
+ )
+ return_value["cache_key"] = cache_key # type: ignore
+
+ return return_value
@property
def cache_timeout(self) -> int:
@@ -203,7 +235,22 @@ class QueryContext:
return self.datasource.database.cache_timeout
return config["CACHE_DEFAULT_TIMEOUT"]
- def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ def cache_key(self, **extra: Any) -> str:
+ """
+ The QueryContext cache key is made up of the key/values from
+ self.cache_values, plus any other key/values in `extra`. It includes only data
+ required to rehydrate a QueryContext object.
+ """
+ key_prefix = "qc-"
+ cache_dict = self.cache_values.copy()
+ cache_dict.update(extra)
+
+ return generate_cache_key(cache_dict, key_prefix)
+
+ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ """
+ Returns a QueryObject cache key for objects in self.queries
+ """
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
cache_key = (
@@ -215,7 +262,7 @@ class QueryContext:
and self.datasource.is_rls_supported
else [],
changed_on=self.datasource.changed_on,
- **kwargs
+ **kwargs,
)
if query_obj
else None
@@ -298,12 +345,12 @@ class QueryContext:
self, query_obj: QueryObject, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
- cache_key = self.cache_key(query_obj, **kwargs)
+ force_cached = kwargs.get("force_cached", False)
+ cache_key = self.query_cache_key(query_obj)
logger.info("Cache key: %s", cache_key)
is_loaded = False
stacktrace = None
df = pd.DataFrame()
- cached_dttm = datetime.utcnow().isoformat().split(".")[0]
cache_value = None
status = None
query = ""
@@ -327,6 +374,12 @@ class QueryContext:
)
logger.info("Serving from cache")
+ if force_cached and not is_loaded:
+ logger.warning(
+ "force_cached (QueryContext): value not found for key %s", cache_key
+ )
+ raise CacheLoadError("Error loading data from cache")
+
if query_obj and not is_loaded:
try:
invalid_columns = [
@@ -367,13 +420,11 @@ class QueryContext:
if is_loaded and cache_key and status != utils.QueryStatus.FAILED:
set_and_log_cache(
- cache_key=cache_key,
- df=df,
- query=query,
- annotation_data=annotation_data,
- cached_dttm=cached_dttm,
- cache_timeout=self.cache_timeout,
- datasource_uid=self.datasource.uid,
+ cache_manager.data_cache,
+ cache_key,
+ {"df": df, "query": query, "annotation_data": annotation_data},
+ self.cache_timeout,
+ self.datasource.uid,
)
return {
"cache_key": cache_key,
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 6eb1231f2e..edcc325fed 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -277,7 +277,8 @@ class QueryObject:
:param df: DataFrame returned from database model.
:return: new DataFrame to which all post processing operations have been
applied
- :raises ChartDataValidationError: If the post processing operation in incorrect
+ :raises QueryObjectValidationError: If the post processing operation
+ is incorrect
"""
for post_process in self.post_processing:
operation = post_process.get("operation")
diff --git a/superset/config.py b/superset/config.py
index 22d3cc5262..f20cbff716 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -327,6 +327,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
"DISPLAY_MARKDOWN_HTML": True,
# When True, this escapes HTML (rather than rendering it) in Markdown components
"ESCAPE_MARKDOWN_HTML": False,
+ "GLOBAL_ASYNC_QUERIES": False,
"VERSIONED_EXPORT": False,
# Note that: RowLevelSecurityFilter is only given by default to the Admin role
# and the Admin Role does have the all_datasources security permission.
@@ -406,6 +407,9 @@ CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
# Cache for datasource metadata and query results
DATA_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
+# store cache keys by datasource UID (via CacheKey) for custom processing/invalidation
+STORE_CACHE_KEYS_IN_METADATA_DB = False
+
# CORS Options
ENABLE_CORS = False
CORS_OPTIONS: Dict[Any, Any] = {}
@@ -965,6 +969,23 @@ SIP_15_TOAST_MESSAGE = (
# conventions and such. You can find examples in the tests.
SQLA_TABLE_MUTATOR = lambda table: table
+# Global async query config options.
+# Requires GLOBAL_ASYNC_QUERIES feature flag to be enabled.
+GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
+ "port": 6379,
+ "host": "127.0.0.1",
+ "password": "",
+ "db": 0,
+}
+GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-"
+GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
+GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
+GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
+GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
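+# Note: override the JWT secret with a secure random value of at least 32 bytes;
+# shorter secrets are rejected when async queries are initialized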
+GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
+GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
+GLOBAL_ASYNC_QUERIES_POLLING_DELAY = 500
+
if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not necessarily in pythonpath; useful
# for case where app is being executed via pex.
diff --git a/superset/exceptions.py b/superset/exceptions.py
index fd95a59e2c..52c19c3a02 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from flask_babel import gettext as _
@@ -65,6 +65,14 @@ class SupersetSecurityException(SupersetException):
self.payload = payload
+class SupersetVizException(SupersetException):
+ status = 400
+
+ def __init__(self, errors: List[SupersetError]) -> None:
+ super(SupersetVizException, self).__init__(str(errors))
+ self.errors = errors
+
+
class NoDataException(SupersetException):
status = 400
@@ -93,6 +101,10 @@ class QueryObjectValidationError(SupersetException):
status = 400
+class CacheLoadError(SupersetException):
+ status = 404
+
+
class DashboardImportException(SupersetException):
pass
diff --git a/superset/extensions.py b/superset/extensions.py
index 7011bb237d..8f5bc6d5a3 100644
--- a/superset/extensions.py
+++ b/superset/extensions.py
@@ -27,6 +27,7 @@ from flask_talisman import Talisman
from flask_wtf.csrf import CSRFProtect
from werkzeug.local import LocalProxy
+from superset.utils.async_query_manager import AsyncQueryManager
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager
from superset.utils.machine_auth import MachineAuthProviderFactory
@@ -97,6 +98,7 @@ class UIManifestProcessor:
APP_DIR = os.path.dirname(__file__)
appbuilder = AppBuilder(update_perms=False)
+async_query_manager = AsyncQueryManager()
cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
new file mode 100644
index 0000000000..b8db82b9af
--- /dev/null
+++ b/superset/tasks/async_queries.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+from typing import Any, cast, Dict, Optional
+
+from flask import current_app
+
+from superset import app
+from superset.exceptions import SupersetVizException
+from superset.extensions import async_query_manager, cache_manager, celery_app
+from superset.utils.cache import generate_cache_key, set_and_log_cache
+from superset.views.utils import get_datasource_info, get_viz
+
+logger = logging.getLogger(__name__)
+query_timeout = current_app.config[
+ "SQLLAB_ASYNC_TIME_LIMIT_SEC"
+] # TODO: new config key
+
+
+@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
+def load_chart_data_into_cache(
+ job_metadata: Dict[str, Any], form_data: Dict[str, Any],
+) -> None:
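+ """Celery task: run a chart data query, cache the result and update the async job metadata."""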
+ from superset.charts.commands.data import (
+ ChartDataCommand,
+ ) # load here due to circular imports
+
+ with app.app_context(): # type: ignore
+ try:
+ command = ChartDataCommand()
+ command.set_query_context(form_data)
+ result = command.run(cache=True)
+ cache_key = result["cache_key"]
+ result_url = f"/api/v1/chart/data/{cache_key}"
+ async_query_manager.update_job(
+ job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
+ )
+ except Exception as exc:
+ # TODO: QueryContext should support SIP-40 style errors
+ error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member
+ errors = [{"message": error}]
+ async_query_manager.update_job(
+ job_metadata, async_query_manager.STATUS_ERROR, errors=errors
+ )
+ raise exc
+
+ return None
+
+
+@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
+def load_explore_json_into_cache(
+ job_metadata: Dict[str, Any],
+ form_data: Dict[str, Any],
+ response_type: Optional[str] = None,
+ force: bool = False,
+) -> None:
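+ """Celery task: run an explore_json query, cache the form_data for retrieval and update the async job metadata."""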
+ with app.app_context(): # type: ignore
+ cache_key_prefix = "ejr-" # ejr: explore_json request
+ try:
+ datasource_id, datasource_type = get_datasource_info(None, None, form_data)
+
+ viz_obj = get_viz(
+ datasource_type=cast(str, datasource_type),
+ datasource_id=datasource_id,
+ form_data=form_data,
+ force=force,
+ )
+ # run query & cache results
+ payload = viz_obj.get_payload()
+ if viz_obj.has_error(payload):
+ raise SupersetVizException(errors=payload["errors"])
+
+ # cache form_data for async retrieval
+ cache_value = {"form_data": form_data, "response_type": response_type}
+ cache_key = generate_cache_key(cache_value, cache_key_prefix)
+ set_and_log_cache(cache_manager.cache, cache_key, cache_value)
+ result_url = f"/superset/explore_json/data/{cache_key}"
+ async_query_manager.update_job(
+ job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
+ )
+ except Exception as exc:
+ if isinstance(exc, SupersetVizException):
+ errors = exc.errors # pylint: disable=no-member
+ else:
+ error = (
+ exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member
+ )
+ errors = [error]
+
+ async_query_manager.update_job(
+ job_metadata, async_query_manager.STATUS_ERROR, errors=errors
+ )
+ raise exc
+
+ return None
diff --git a/superset/utils/async_query_manager.py b/superset/utils/async_query_manager.py
new file mode 100644
index 0000000000..42d2c130bd
--- /dev/null
+++ b/superset/utils/async_query_manager.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import logging
+import uuid
+from typing import Any, Dict, List, Optional, Tuple
+
+import jwt
+import redis
+from flask import Flask, Request, Response, session
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncQueryTokenException(Exception):
+ pass
+
+
+class AsyncQueryJobException(Exception):
+ pass
+
+
+def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]:
+ return {
+ "channel_id": channel_id,
+ "job_id": job_id,
+ "user_id": session.get("user_id"),
+ "status": kwargs.get("status"),
+ "errors": kwargs.get("errors", []),
+ "result_url": kwargs.get("result_url"),
+ }
+
+
+def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]:
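+ # event_data is an (id, {"data": <json>}) tuple as returned by the Redis stream XRANGE command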
+ event_id = event_data[0]
+ event_payload = event_data[1]["data"]
+ return {"id": event_id, **json.loads(event_payload)}
+
+
+def increment_id(redis_id: str) -> str:
+ # redis stream IDs are in this format: '1607477697866-0'
+ try:
+ # split on the final "-" so multi-digit sequence numbers increment correctly
+ prefix, last = redis_id.rsplit("-", 1)
+ return f"{prefix}-{int(last) + 1}"
+ except Exception: # pylint: disable=broad-except
+ return redis_id
+
+
+class AsyncQueryManager:
+ MAX_EVENT_COUNT = 100
+ STATUS_PENDING = "pending"
+ STATUS_RUNNING = "running"
+ STATUS_ERROR = "error"
+ STATUS_DONE = "done"
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._redis: redis.Redis
+ self._stream_prefix: str = ""
+ self._stream_limit: Optional[int]
+ self._stream_limit_firehose: Optional[int]
+ self._jwt_cookie_name: str
+ self._jwt_cookie_secure: bool = False
+ self._jwt_secret: str
+
+ def init_app(self, app: Flask) -> None:
+ config = app.config
+ if (
+ config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
+ or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
+ ):
+ raise Exception(
+ """
+ Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
+ and non-null in order to enable async queries
+ """
+ )
+
+ if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
+ raise AsyncQueryTokenException(
+ "Please provide a JWT secret at least 32 bytes long"
+ )
+
+ self._redis = redis.Redis( # type: ignore
+ **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
+ )
+ self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
+ self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
+ self._stream_limit_firehose = config[
+ "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE"
+ ]
+ self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"]
+ self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"]
+ self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
+
+ @app.after_request
+ def validate_session( # pylint: disable=unused-variable
+ response: Response,
+ ) -> Response:
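+ # Regenerate the async channel ID and JWT cookie whenever the session
+ # has no channel yet or the logged-in user has changed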
+ reset_token = False
+ user_id = session["user_id"] if "user_id" in session else None
+
+ if "async_channel_id" not in session or "async_user_id" not in session:
+ reset_token = True
+ elif user_id != session["async_user_id"]:
+ reset_token = True
+
+ if reset_token:
+ async_channel_id = str(uuid.uuid4())
+ session["async_channel_id"] = async_channel_id
+ session["async_user_id"] = user_id
+
+ sub = str(user_id) if user_id else None
+ token = self.generate_jwt({"channel": async_channel_id, "sub": sub})
+
+ response.set_cookie(
+ self._jwt_cookie_name,
+ value=token,
+ httponly=True,
+ secure=self._jwt_cookie_secure,
+ # max_age=max_age or config.cookie_max_age,
+ # domain=config.cookie_domain,
+ # path=config.access_cookie_path,
+ # samesite=config.cookie_samesite
+ )
+
+ return response
+
+ def generate_jwt(self, data: Dict[str, Any]) -> str:
+ encoded_jwt = jwt.encode(data, self._jwt_secret, algorithm="HS256")
+ return encoded_jwt.decode("utf-8")
+
+ def parse_jwt(self, token: str) -> Dict[str, Any]:
+ data = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
+ return data
+
+ def parse_jwt_from_request(self, request: Request) -> Dict[str, Any]:
+ token = request.cookies.get(self._jwt_cookie_name)
+ if not token:
+ raise AsyncQueryTokenException("Token not preset")
+
+ try:
+ return self.parse_jwt(token)
+ except Exception as exc:
+ logger.warning(exc)
+ raise AsyncQueryTokenException("Failed to parse token")
+
+ def init_job(self, channel_id: str) -> Dict[str, Any]:
+ job_id = str(uuid.uuid4())
+ return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING)
+
+ def read_events(
+ self, channel: str, last_id: Optional[str]
+ ) -> List[Optional[Dict[str, Any]]]:
+ stream_name = f"{self._stream_prefix}{channel}"
+ start_id = increment_id(last_id) if last_id else "-"
+ results = self._redis.xrange( # type: ignore
+ stream_name, start_id, "+", self.MAX_EVENT_COUNT
+ )
+ return [] if not results else list(map(parse_event, results))
+
+ def update_job(
+ self, job_metadata: Dict[str, Any], status: str, **kwargs: Any
+ ) -> None:
+ if "channel_id" not in job_metadata:
+ raise AsyncQueryJobException("No channel ID specified")
+
+ if "job_id" not in job_metadata:
+ raise AsyncQueryJobException("No job ID specified")
+
+ updates = {"status": status, **kwargs}
+ event_data = {"data": json.dumps({**job_metadata, **updates})}
+
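+ # Publish the update to both the session-scoped stream and the global "firehose" stream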
+ full_stream_name = f"{self._stream_prefix}full"
+ scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}"
+
+ logger.debug("********** logging event data to stream %s", scoped_stream_name)
+ logger.debug(event_data)
+
+ self._redis.xadd( # type: ignore
+ scoped_stream_name, event_data, "*", self._stream_limit
+ )
+ self._redis.xadd( # type: ignore
+ full_stream_name, event_data, "*", self._stream_limit_firehose
+ )
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index f0b24b26ae..729c316866 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -14,16 +14,65 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import hashlib
+import json
import logging
from datetime import datetime, timedelta
from functools import wraps
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union
from flask import current_app as app, request
from flask_caching import Cache
from werkzeug.wrappers.etag import ETagResponseMixin
+from superset import db
from superset.extensions import cache_manager
+from superset.models.cache import CacheKey
+from superset.stats_logger import BaseStatsLogger
+from superset.utils.core import json_int_dttm_ser
+
+config = app.config # type: ignore
+stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
+logger = logging.getLogger(__name__)
+
+# TODO: DRY up cache key code
+def json_dumps(obj: Any, sort_keys: bool = False) -> str:
+ return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys)
+
+
+def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str:
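+ # Keys are an MD5 hash of the JSON-serialized values, namespaced by a short prefix (e.g. "qc-", "ejr-")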
+ json_data = json_dumps(values_dict, sort_keys=True)
+ hash_str = hashlib.md5(json_data.encode("utf-8")).hexdigest()
+ return f"{key_prefix}{hash_str}"
+
+
+def set_and_log_cache(
+ cache_instance: Cache,
+ cache_key: str,
+ cache_value: Dict[str, Any],
+ cache_timeout: Optional[int] = None,
+ datasource_uid: Optional[str] = None,
+) -> None:
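+ # Writes the value (with a timestamp) to the given cache and optionally records
+ # the key in the metadata DB for custom invalidation (see STORE_CACHE_KEYS_IN_METADATA_DB)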
+ timeout = cache_timeout if cache_timeout else config["CACHE_DEFAULT_TIMEOUT"]
+ try:
+ dttm = datetime.utcnow().isoformat().split(".")[0]
+ value = {**cache_value, "dttm": dttm}
+ cache_instance.set(cache_key, value, timeout=timeout)
+ stats_logger.incr("set_cache_key")
+
+ if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
+ ck = CacheKey(
+ cache_key=cache_key,
+ cache_timeout=cache_timeout,
+ datasource_uid=datasource_uid,
+ )
+ db.session.add(ck)
+ except Exception as ex: # pylint: disable=broad-except
+ # cache.set call can fail if the backend is down or if
+ # the key is too large or whatever other reasons
+ logger.warning("Could not cache key %s", cache_key)
+ logger.exception(ex)
+
# If a user sets `max_age` to 0, how long should the browser cache the
# resource? Flask-Caching will cache forever, but for the HTTP header we need
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index 84ee9b7ffb..5d4c2013dd 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -238,7 +238,7 @@ def pivot( # pylint: disable=too-many-arguments
Default to 'All'.
:param flatten_columns: Convert column names to strings
:return: A pivot table
- :raises ChartDataValidationError: If the request in incorrect
+ :raises QueryObjectValidationError: If the request is incorrect
"""
if not index:
raise QueryObjectValidationError(
@@ -293,7 +293,7 @@ def aggregate(
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
- :raises ChartDataValidationError: If the request in incorrect
+ :raises QueryObjectValidationError: If the request is incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
@@ -313,7 +313,7 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
:param columns: columns by which to sort. The key specifies the column name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
- :raises ChartDataValidationError: If the request in incorrect
+ :raises QueryObjectValidationError: If the request is incorrect
"""
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
@@ -348,7 +348,7 @@ def rolling( # pylint: disable=too-many-arguments
:param min_periods: The minimum amount of periods required for a row to be included
in the result set.
:return: DataFrame with the rolling columns
- :raises ChartDataValidationError: If the request in incorrect
+ :raises QueryObjectValidationError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
df_rolling = df[columns.keys()]
@@ -408,7 +408,7 @@ def select(
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
- :raises ChartDataValidationError: If the request in incorrect
+ :raises QueryObjectValidationError: If the request is incorrect
"""
df_select = df.copy(deep=False)
if columns:
@@ -433,7 +433,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
unchanged.
:param periods: periods to shift for calculating difference.
:return: DataFrame with diffed columns
- :raises ChartDataValidationError: If the request in incorrect
+ :raises QueryObjectValidationError: If the request is incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods)
diff --git a/superset/views/api.py b/superset/views/api.py
index a5090b31c9..1b19455126 100644
--- a/superset/views/api.py
+++ b/superset/views/api.py
@@ -44,7 +44,8 @@ class Api(BaseSupersetView):
"""
query_context = QueryContext(**json.loads(request.form["query_context"]))
query_context.raise_for_access()
- payload_json = query_context.get_payload()
+ result = query_context.get_payload()
+ payload_json = result["queries"]
return json.dumps(
payload_json, default=utils.json_int_dttm_ser, ignore_nan=True
)
diff --git a/superset/views/base.py b/superset/views/base.py
index 2f87a81f97..9ed65222bc 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -77,6 +77,8 @@ FRONTEND_CONF_KEYS = (
"SUPERSET_WEBSERVER_DOMAINS",
"SQLLAB_SAVE_WARNING_MESSAGE",
"DISPLAY_MAX_ROW",
+ "GLOBAL_ASYNC_QUERIES_TRANSPORT",
+ "GLOBAL_ASYNC_QUERIES_POLLING_DELAY",
)
logger = logging.getLogger(__name__)
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 0f05f07ee1..2e0d3b2369 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -126,6 +126,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
method_permission_name = {
"bulk_delete": "delete",
"data": "list",
+ "data_from_cache": "list",
"delete": "delete",
"distinct": "list",
"export": "mulexport",
diff --git a/superset/views/core.py b/superset/views/core.py
index 4e062da974..63591b73ef 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -28,7 +28,11 @@ import simplejson as json
from flask import abort, flash, g, Markup, redirect, render_template, request, Response
from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
-from flask_appbuilder.security.decorators import has_access, has_access_api
+from flask_appbuilder.security.decorators import (
+ has_access,
+ has_access_api,
+ permission_name,
+)
from flask_appbuilder.security.sqla import models as ab_models
from flask_babel import gettext as __, lazy_gettext as _
from jinja2.exceptions import TemplateError
@@ -66,6 +70,7 @@ from superset.dashboards.dao import DashboardDAO
from superset.databases.dao import DatabaseDAO
from superset.databases.filters import DatabaseFilter
from superset.exceptions import (
+ CacheLoadError,
CertificateException,
DatabaseNotFound,
SerializationError,
@@ -73,6 +78,7 @@ from superset.exceptions import (
SupersetSecurityException,
SupersetTimeoutException,
)
+from superset.extensions import async_query_manager, cache_manager
from superset.jinja_context import get_template_processor
from superset.models.core import Database, FavStar, Log
from superset.models.dashboard import Dashboard
@@ -87,8 +93,10 @@ from superset.security.analytics_db_safety import (
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
+from superset.tasks.async_queries import load_explore_json_into_cache
from superset.typing import FlaskResponse
from superset.utils import core as utils
+from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.dates import now_as_float
from superset.views.base import (
@@ -113,6 +121,7 @@ from superset.views.utils import (
apply_display_max_row_limit,
bootstrap_user_data,
check_datasource_perms,
+ check_explore_cache_perms,
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
@@ -484,6 +493,43 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
payload = viz_obj.get_payload()
return data_payload_response(*viz_obj.payload_json_and_has_error(payload))
+ @event_logger.log_this
+ @api
+ @has_access_api
+ @handle_api_exception
+ @permission_name("explore_json")
+ @expose("/explore_json/data/", methods=["GET"])
+ @etag_cache(check_perms=check_explore_cache_perms)
+ def explore_json_data(self, cache_key: str) -> FlaskResponse:
+ """Serves cached result data for async explore_json calls
+
+ `self.generate_json` receives this input and returns different
+ payloads based on the request args in the first block
+
+ TODO: form_data should not be loaded twice from cache
+ (also loaded in `check_explore_cache_perms`)
+ """
+ try:
+ cached = cache_manager.cache.get(cache_key)
+ if not cached:
+ raise CacheLoadError("Cached data not found")
+
+ form_data = cached.get("form_data")
+ response_type = cached.get("response_type")
+
+ datasource_id, datasource_type = get_datasource_info(None, None, form_data)
+
+ viz_obj = get_viz(
+ datasource_type=cast(str, datasource_type),
+ datasource_id=datasource_id,
+ form_data=form_data,
+ force_cached=True,
+ )
+
+ return self.generate_json(viz_obj, response_type)
+ except SupersetException as ex:
+ return json_error_response(utils.error_msg_from_exception(ex), 400)
+
EXPLORE_JSON_METHODS = ["POST"]
if not is_feature_enabled("ENABLE_EXPLORE_JSON_CSRF_PROTECTION"):
EXPLORE_JSON_METHODS.append("GET")
@@ -528,11 +574,31 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource_id, datasource_type, form_data
)
+ force = request.args.get("force") == "true"
+
+ # TODO: support CSV, SQL query and other non-JSON types
+ if (
+ is_feature_enabled("GLOBAL_ASYNC_QUERIES")
+ and response_type == utils.ChartDataResultFormat.JSON
+ ):
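+ # Execute the query in a Celery worker and return the job metadata with a 202 status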
+ try:
+ async_channel_id = async_query_manager.parse_jwt_from_request(
+ request
+ )["channel"]
+ job_metadata = async_query_manager.init_job(async_channel_id)
+ load_explore_json_into_cache.delay(
+ job_metadata, form_data, response_type, force
+ )
+ except AsyncQueryTokenException:
+ return json_error_response("Not authorized", 401)
+
+ return json_success(json.dumps(job_metadata), status=202)
+
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
- force=request.args.get("force") == "true",
+ force=force,
)
return self.generate_json(viz_obj, response_type)
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 08de168003..28104aa86a 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -34,10 +34,12 @@ from superset import app, dataframe, db, is_feature_enabled, result_set
from superset.connectors.connector_registry import ConnectorRegistry
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
+ CacheLoadError,
SerializationError,
SupersetException,
SupersetSecurityException,
)
+from superset.extensions import cache_manager
from superset.legacy import update_time_range
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@@ -108,13 +110,19 @@ def get_permissions(
def get_viz(
- form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
+ form_data: FormData,
+ datasource_type: str,
+ datasource_id: int,
+ force: bool = False,
+ force_cached: bool = False,
) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
- viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
+ viz_obj = viz.viz_types[viz_type](
+ datasource, form_data=form_data, force=force, force_cached=force_cached
+ )
return viz_obj
@@ -422,10 +430,26 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
return obj and user in obj.owners
+def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
+ """
+ Loads async explore_json request data from cache and performs access check
+
+ :param _self: the Superset view instance
+ :param cache_key: the cache key passed into /explore_json/data/
+ :raises SupersetSecurityException: If the user cannot access the resource
+ """
+ cached = cache_manager.cache.get(cache_key)
+ if not cached:
+ raise CacheLoadError("Cached data not found")
+
+ check_datasource_perms(_self, form_data=cached["form_data"])
+
+
def check_datasource_perms(
_self: Any,
datasource_type: Optional[str] = None,
datasource_id: Optional[int] = None,
+ **kwargs: Any
) -> None:
"""
Check if user can access a cached response from explore_json.
@@ -438,7 +462,7 @@ def check_datasource_perms(
:raises SupersetSecurityException: If the user cannot access the resource
"""
- form_data = get_form_data()[0]
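+ # Prefer explicitly passed form_data (e.g. from check_explore_cache_perms) over the request payload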
+ form_data = kwargs["form_data"] if "form_data" in kwargs else get_form_data()[0]
try:
datasource_id, datasource_type = get_datasource_info(
diff --git a/superset/viz.py b/superset/viz.py
index de5d597f73..9b7d46ac40 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -38,6 +38,7 @@ from typing import (
Optional,
Set,
Tuple,
+ Type,
TYPE_CHECKING,
Union,
)
@@ -57,6 +58,7 @@ from superset import app, db, is_feature_enabled
from superset.constants import NULL_STRING
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
+ CacheLoadError,
NullValueException,
QueryObjectValidationError,
SpatialException,
@@ -66,6 +68,7 @@ from superset.models.cache import CacheKey
from superset.models.helpers import QueryResult
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
+from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
DTTM_ALIAS,
JS_MAX_INTEGER,
@@ -97,37 +100,6 @@ METRIC_KEYS = [
]
-def set_and_log_cache(
- cache_key: str,
- df: pd.DataFrame,
- query: str,
- cached_dttm: str,
- cache_timeout: int,
- datasource_uid: Optional[str],
- annotation_data: Optional[Dict[str, Any]] = None,
-) -> None:
- try:
- cache_value = dict(
- dttm=cached_dttm, df=df, query=query, annotation_data=annotation_data or {}
- )
- stats_logger.incr("set_cache_key")
- cache_manager.data_cache.set(cache_key, cache_value, timeout=cache_timeout)
-
- if datasource_uid:
- ck = CacheKey(
- cache_key=cache_key,
- cache_timeout=cache_timeout,
- datasource_uid=datasource_uid,
- )
- db.session.add(ck)
- except Exception as ex:
- # cache.set call can fail if the backend is down or if
- # the key is too large or whatever other reasons
- logger.warning("Could not cache key {}".format(cache_key))
- logger.exception(ex)
- cache_manager.data_cache.delete(cache_key)
-
-
class BaseViz:
"""All visualizations derive this base class"""
@@ -144,6 +116,7 @@ class BaseViz:
datasource: "BaseDatasource",
form_data: Dict[str, Any],
force: bool = False,
+ force_cached: bool = False,
) -> None:
if not datasource:
raise QueryObjectValidationError(_("Viz is missing a datasource"))
@@ -164,6 +137,7 @@ class BaseViz:
self.results: Optional[QueryResult] = None
self.errors: List[Dict[str, Any]] = []
self.force = force
+ self._force_cached = force_cached
self.from_dttm: Optional[datetime] = None
self.to_dttm: Optional[datetime] = None
@@ -180,6 +154,10 @@ class BaseViz:
self.applied_filters: List[Dict[str, str]] = []
self.rejected_filters: List[Dict[str, str]] = []
+ @property
+ def force_cached(self) -> bool:
+ return self._force_cached
+
def process_metrics(self) -> None:
# metrics in TableViz is order sensitive, so metric_dict should be
# OrderedDict
@@ -272,7 +250,7 @@ class BaseViz:
"columns": [o.column_name for o in self.datasource.columns],
}
)
- df = self.get_df(query_obj)
+ df = self.get_df_payload(query_obj)["df"] # leverage caching logic
return df.to_dict(orient="records")
def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
@@ -518,7 +496,6 @@ class BaseViz:
is_loaded = False
stacktrace = None
df = None
- cached_dttm = datetime.utcnow().isoformat().split(".")[0]
if cache_key and cache_manager.data_cache and not self.force:
cache_value = cache_manager.data_cache.get(cache_key)
if cache_value:
@@ -539,6 +516,11 @@ class BaseViz:
logger.info("Serving from cache")
if query_obj and not is_loaded:
+ if self.force_cached:
+ logger.warning(
+ f"force_cached (viz.py): value not found for cache key {cache_key}"
+ )
+ raise CacheLoadError(_("Cached value not found"))
try:
invalid_columns = [
col
@@ -590,12 +572,11 @@ class BaseViz:
if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED:
set_and_log_cache(
- cache_key=cache_key,
- df=df,
- query=self.query,
- cached_dttm=cached_dttm,
- cache_timeout=self.cache_timeout,
- datasource_uid=self.datasource.uid,
+ cache_manager.data_cache,
+ cache_key,
+ {"df": df, "query": self.query},
+ self.cache_timeout,
+ self.datasource.uid,
)
return {
"cache_key": self._any_cache_key,
@@ -618,13 +599,15 @@ class BaseViz:
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
- def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
- has_error = (
+ def has_error(self, payload: VizPayload) -> bool:
+ return (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
or bool(payload.get("errors"))
)
- return self.json_dumps(payload), has_error
+
+ def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
+ return self.json_dumps(payload), self.has_error(payload)
@property
def data(self) -> Dict[str, Any]:
@@ -638,7 +621,7 @@ class BaseViz:
return content
def get_csv(self) -> Optional[str]:
- df = self.get_df()
+ df = self.get_df_payload()["df"] # leverage caching logic
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])
@@ -1721,7 +1704,7 @@ class SunburstViz(BaseViz):
def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
- fd = self.form_data
+ fd = copy.deepcopy(self.form_data)
cols = fd.get("groupby") or []
cols.extend(["m1", "m2"])
metric = utils.get_metric_name(fd["metric"])
@@ -2983,12 +2966,14 @@ class PartitionViz(NVD3TimeSeriesViz):
return self.nest_values(levels)
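+# Walk subclasses recursively so that viz classes inheriting indirectly from BaseViz are also registered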
+def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]:
+ return set(cls.__subclasses__()).union(
+ [sc for c in cls.__subclasses__() for sc in get_subclasses(c)]
+ )
+
+
viz_types = {
o.viz_type: o
- for o in globals().values()
- if (
- inspect.isclass(o)
- and issubclass(o, BaseViz)
- and o.viz_type not in config["VIZ_TYPE_DENYLIST"]
- )
+ for o in get_subclasses(BaseViz)
+ if o.viz_type not in config["VIZ_TYPE_DENYLIST"]
}
diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py
index 600f44141a..798ce42f27 100644
--- a/superset/viz_sip38.py
+++ b/superset/viz_sip38.py
@@ -57,13 +57,13 @@ from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
+from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
DTTM_ALIAS,
JS_MAX_INTEGER,
merge_extra_filters,
to_adhoc,
)
-from superset.viz import set_and_log_cache
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@@ -518,10 +518,9 @@ class BaseViz:
if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED:
set_and_log_cache(
+ cache_manager.data_cache,
cache_key,
- df,
- self.query,
- cached_dttm,
+ {"df": df, "query": self.query},
self.cache_timeout,
self.datasource.uid,
)
diff --git a/tests/async_events/__init__.py b/tests/async_events/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/async_events/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/async_events/api_tests.py b/tests/async_events/api_tests.py
new file mode 100644
index 0000000000..04d838b97b
--- /dev/null
+++ b/tests/async_events/api_tests.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+from typing import Optional
+from unittest import mock
+
+from superset.extensions import async_query_manager
+from tests.base_tests import SupersetTestCase
+from tests.test_app import app
+
+
+class TestAsyncEventApi(SupersetTestCase):
+ UUID = "943c920-32a5-412a-977d-b8e47d36f5a4"
+
+ def fetch_events(self, last_id: Optional[str] = None):
+ base_uri = "api/v1/async_event/"
+ uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri
+ return self.client.get(uri)
+
+ @mock.patch("uuid.uuid4", return_value=UUID)
+ def test_events(self, mock_uuid4):
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
+ rv = self.fetch_events()
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 200
+ channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
+ mock_xrange.assert_called_with(channel_id, "-", "+", 100)
+ self.assertEqual(response, {"result": []})
+
+ @mock.patch("uuid.uuid4", return_value=UUID)
+ def test_events_last_id(self, mock_uuid4):
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
+ rv = self.fetch_events("1607471525180-0")
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 200
+ channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
+ mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
+ self.assertEqual(response, {"result": []})
+
+ @mock.patch("uuid.uuid4", return_value=UUID)
+ def test_events_results(self, mock_uuid4):
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
+ mock_xrange.return_value = [
+ (
+ "1607477697866-0",
+ {
+ "data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463"}'
+ },
+ ),
+ (
+ "1607477697993-0",
+ {
+ "data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d"}'
+ },
+ ),
+ ]
+ rv = self.fetch_events()
+ response = json.loads(rv.data.decode("utf-8"))
+
+ assert rv.status_code == 200
+ channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
+ mock_xrange.assert_called_with(channel_id, "-", "+", 100)
+ expected = {
+ "result": [
+ {
+ "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
+ "errors": [],
+ "id": "1607477697866-0",
+ "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512",
+ "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463",
+ "status": "done",
+ "user_id": "1",
+ },
+ {
+ "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
+ "errors": [],
+ "id": "1607477697993-0",
+ "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c",
+ "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d",
+ "status": "done",
+ "user_id": "1",
+ },
+ ]
+ }
+ self.assertEqual(response, expected)
+
+ def test_events_no_login(self):
+ async_query_manager.init_app(app)
+ rv = self.fetch_events()
+ assert rv.status_code == 401
+
+ def test_events_no_token(self):
+ self.login(username="admin")
+ self.client.set_cookie(
+ "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], ""
+ )
+ rv = self.fetch_events()
+ assert rv.status_code == 401
diff --git a/tests/cache_tests.py b/tests/cache_tests.py
index 0b887622ef..3ffd52a378 100644
--- a/tests/cache_tests.py
+++ b/tests/cache_tests.py
@@ -35,6 +35,7 @@ class TestCache(SupersetTestCase):
cache_manager.data_cache.clear()
def test_no_data_cache(self):
+ data_cache_config = app.config["DATA_CACHE_CONFIG"]
app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}
cache_manager.init_app(app)
@@ -48,11 +49,15 @@ class TestCache(SupersetTestCase):
resp_from_cache = self.get_json_resp(
json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
)
+ # restore DATA_CACHE_CONFIG
+ app.config["DATA_CACHE_CONFIG"] = data_cache_config
self.assertFalse(resp["is_cached"])
self.assertFalse(resp_from_cache["is_cached"])
def test_slice_data_cache(self):
# Override cache config
+ data_cache_config = app.config["DATA_CACHE_CONFIG"]
+ cache_default_timeout = app.config["CACHE_DEFAULT_TIMEOUT"]
app.config["CACHE_DEFAULT_TIMEOUT"] = 100
app.config["DATA_CACHE_CONFIG"] = {
"CACHE_TYPE": "simple",
@@ -87,5 +92,6 @@ class TestCache(SupersetTestCase):
self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"]))
# reset cache config
- app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}
+ app.config["DATA_CACHE_CONFIG"] = data_cache_config
+ app.config["CACHE_DEFAULT_TIMEOUT"] = cache_default_timeout
cache_manager.init_app(app)
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 27c88886e1..092c354199 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -31,21 +31,21 @@ from sqlalchemy import and_
from sqlalchemy.sql import func
from tests.test_app import app
-from superset.connectors.sqla.models import SqlaTable
-from superset.utils.core import AnnotationType, get_example_database
-from tests.fixtures.energy_dashboard import load_energy_table_with_slice
-from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
+from superset.charts.commands.data import ChartDataCommand
from superset.connectors.connector_registry import ConnectorRegistry
-from superset.extensions import db, security_manager
+from superset.connectors.sqla.models import SqlaTable
+from superset.extensions import async_query_manager, cache_manager, db, security_manager
from superset.models.annotations import AnnotationLayer
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils import core as utils
+from superset.utils.core import AnnotationType, get_example_database
from tests.base_api_tests import ApiOwnersTestCaseMixin
-from tests.base_tests import SupersetTestCase
+from tests.base_tests import SupersetTestCase, post_assert_metric, test_client
+
from tests.fixtures.importexport import (
chart_config,
chart_metadata_config,
@@ -53,6 +53,7 @@ from tests.fixtures.importexport import (
dataset_config,
dataset_metadata_config,
)
+from tests.fixtures.energy_dashboard import load_energy_table_with_slice
from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
from tests.annotation_layers.fixtures import create_annotation_layers
@@ -99,6 +100,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
db.session.commit()
return slice
+ @pytest.fixture(autouse=True)
+ def clear_data_cache(self):
+ with app.app_context():
+ cache_manager.data_cache.clear()
+ yield
+
@pytest.fixture()
def create_charts(self):
with self.create_app().app_context():
@@ -1287,6 +1294,155 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ def test_chart_data_async(self):
+ """
+ Chart data API: Test chart data query (async)
+ """
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ table = self.get_table_by_name("birth_names")
+ request_payload = get_query_context(table.name, table.id, table.type)
+ rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+ self.assertEqual(rv.status_code, 202)
+ data = json.loads(rv.data.decode("utf-8"))
+ keys = list(data.keys())
+ self.assertCountEqual(
+ keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
+ )
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ def test_chart_data_async_results_type(self):
+ """
+ Chart data API: Test chart data query non-JSON format (async)
+ """
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ table = self.get_table_by_name("birth_names")
+ request_payload = get_query_context(table.name, table.id, table.type)
+ request_payload["result_type"] = "results"
+ rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+ self.assertEqual(rv.status_code, 200)
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ def test_chart_data_async_invalid_token(self):
+ """
+ Chart data API: Test chart data query (async)
+ """
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ table = self.get_table_by_name("birth_names")
+ request_payload = get_query_context(table.name, table.id, table.type)
+ test_client.set_cookie(
+ "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
+ )
+ rv = post_assert_metric(test_client, CHART_DATA_URI, request_payload, "data")
+ self.assertEqual(rv.status_code, 401)
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ @mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
+ def test_chart_data_cache(self, load_qc_mock):
+ """
+ Chart data cache API: Test chart data async cache request
+ """
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ table = self.get_table_by_name("birth_names")
+ query_context = get_query_context(table.name, table.id, table.type)
+ load_qc_mock.return_value = query_context
+ orig_run = ChartDataCommand.run
+
+ def mock_run(self, **kwargs):
+ assert kwargs["force_cached"] is True
+ # override force_cached to get result from DB
+ return orig_run(self, force_cached=False)
+
+ with mock.patch.object(ChartDataCommand, "run", new=mock_run):
+ rv = self.get_assert_metric(
+ f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+ )
+ data = json.loads(rv.data.decode("utf-8"))
+
+ self.assertEqual(rv.status_code, 200)
+ self.assertEqual(data["result"][0]["rowcount"], 45)
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ @mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
+ def test_chart_data_cache_run_failed(self, load_qc_mock):
+ """
+ Chart data cache API: Test chart data async cache request with run failure
+ """
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ table = self.get_table_by_name("birth_names")
+ query_context = get_query_context(table.name, table.id, table.type)
+ load_qc_mock.return_value = query_context
+ rv = self.get_assert_metric(
+ f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+ )
+ data = json.loads(rv.data.decode("utf-8"))
+
+ self.assertEqual(rv.status_code, 422)
+ self.assertEqual(data["message"], "Error loading data from cache")
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ @mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
+ def test_chart_data_cache_no_login(self, load_qc_mock):
+ """
+ Chart data cache API: Test chart data async cache request (no login)
+ """
+ async_query_manager.init_app(app)
+ table = self.get_table_by_name("birth_names")
+ query_context = get_query_context(table.name, table.id, table.type)
+ load_qc_mock.return_value = query_context
+ orig_run = ChartDataCommand.run
+
+ def mock_run(self, **kwargs):
+ assert kwargs["force_cached"] is True
+ # override force_cached to get result from DB
+ return orig_run(self, force_cached=False)
+
+ with mock.patch.object(ChartDataCommand, "run", new=mock_run):
+ rv = self.get_assert_metric(
+ f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+ )
+
+ self.assertEqual(rv.status_code, 401)
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ def test_chart_data_cache_key_error(self):
+ """
+ Chart data cache API: Test chart data async cache request with invalid cache key
+ """
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ rv = self.get_assert_metric(
+ f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+ )
+
+ self.assertEqual(rv.status_code, 404)
+
def test_export_chart(self):
"""
Chart API: Test export chart
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 6ac990ba90..e011fd69ec 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -52,6 +52,7 @@ from superset import (
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
+from superset.extensions import async_query_manager
from superset.models import core as models
from superset.models.annotations import Annotation, AnnotationLayer
from superset.models.dashboard import Dashboard
@@ -602,10 +603,13 @@ class TestCore(SupersetTestCase):
) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
def test_cache_logging(self):
+ store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
+ app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
girls_slice = self.get_slice("Girls", db.session)
self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id))
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
assert ck.datasource_uid == f"{girls_slice.table.id}__table"
+ app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys
def test_shortner(self):
self.login(username="admin")
@@ -841,6 +845,191 @@ class TestCore(SupersetTestCase):
"The datasource associated with this chart no longer exists",
)
+ def test_explore_json(self):
+ tbl_id = self.table_ids.get("birth_names")
+ form_data = {
+ "queryFields": {
+ "metrics": "metrics",
+ "groupby": "groupby",
+ "columns": "groupby",
+ },
+ "datasource": f"{tbl_id}__table",
+ "viz_type": "dist_bar",
+ "time_range_endpoints": ["inclusive", "exclusive"],
+ "granularity_sqla": "ds",
+ "time_range": "No filter",
+ "metrics": ["count"],
+ "adhoc_filters": [],
+ "groupby": ["gender"],
+ "row_limit": 100,
+ }
+ self.login(username="admin")
+ rv = self.client.post(
+ "/superset/explore_json/", data={"form_data": json.dumps(form_data)},
+ )
+ data = json.loads(rv.data.decode("utf-8"))
+
+ self.assertEqual(rv.status_code, 200)
+ self.assertEqual(data["rowcount"], 2)
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ def test_explore_json_async(self):
+ tbl_id = self.table_ids.get("birth_names")
+ form_data = {
+ "queryFields": {
+ "metrics": "metrics",
+ "groupby": "groupby",
+ "columns": "groupby",
+ },
+ "datasource": f"{tbl_id}__table",
+ "viz_type": "dist_bar",
+ "time_range_endpoints": ["inclusive", "exclusive"],
+ "granularity_sqla": "ds",
+ "time_range": "No filter",
+ "metrics": ["count"],
+ "adhoc_filters": [],
+ "groupby": ["gender"],
+ "row_limit": 100,
+ }
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ rv = self.client.post(
+ "/superset/explore_json/", data={"form_data": json.dumps(form_data)},
+ )
+ data = json.loads(rv.data.decode("utf-8"))
+ keys = list(data.keys())
+
+ self.assertEqual(rv.status_code, 202)
+ self.assertCountEqual(
+ keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
+ )
+
+ @mock.patch.dict(
+ "superset.extensions.feature_flag_manager._feature_flags",
+ GLOBAL_ASYNC_QUERIES=True,
+ )
+ def test_explore_json_async_results_format(self):
+ tbl_id = self.table_ids.get("birth_names")
+ form_data = {
+ "queryFields": {
+ "metrics": "metrics",
+ "groupby": "groupby",
+ "columns": "groupby",
+ },
+ "datasource": f"{tbl_id}__table",
+ "viz_type": "dist_bar",
+ "time_range_endpoints": ["inclusive", "exclusive"],
+ "granularity_sqla": "ds",
+ "time_range": "No filter",
+ "metrics": ["count"],
+ "adhoc_filters": [],
+ "groupby": ["gender"],
+ "row_limit": 100,
+ }
+ async_query_manager.init_app(app)
+ self.login(username="admin")
+ rv = self.client.post(
+ "/superset/explore_json/?results=true",
+ data={"form_data": json.dumps(form_data)},
+ )
+ self.assertEqual(rv.status_code, 200)
+
+ @mock.patch(
+ "superset.utils.cache_manager.CacheManager.cache",
+ new_callable=mock.PropertyMock,
+ )
+ @mock.patch("superset.viz.BaseViz.force_cached", new_callable=mock.PropertyMock)
+ def test_explore_json_data(self, mock_force_cached, mock_cache):
+ tbl_id = self.table_ids.get("birth_names")
+ form_data = dict(
+ {
+ "form_data": {
+ "queryFields": {
+ "metrics": "metrics",
+ "groupby": "groupby",
+ "columns": "groupby",
+ },
+ "datasource": f"{tbl_id}__table",
+ "viz_type": "dist_bar",
+ "time_range_endpoints": ["inclusive", "exclusive"],
+ "granularity_sqla": "ds",
+ "time_range": "No filter",
+ "metrics": ["count"],
+ "adhoc_filters": [],
+ "groupby": ["gender"],
+ "row_limit": 100,
+ }
+ }
+ )
+
+ class MockCache:
+ def get(self, key):
+ return form_data
+
+ def set(self, key, value, timeout=None):  # accept the standard cache.set signature
+ return None
+
+ mock_cache.return_value = MockCache()
+ mock_force_cached.return_value = False
+
+ self.login(username="admin")
+ rv = self.client.get("/superset/explore_json/data/valid-cache-key")
+ data = json.loads(rv.data.decode("utf-8"))
+
+ self.assertEqual(rv.status_code, 200)
+ self.assertEqual(data["rowcount"], 2)
+
+ @mock.patch(
+ "superset.utils.cache_manager.CacheManager.cache",
+ new_callable=mock.PropertyMock,
+ )
+ def test_explore_json_data_no_login(self, mock_cache):
+ tbl_id = self.table_ids.get("birth_names")
+ form_data = dict(
+ {
+ "form_data": {
+ "queryFields": {
+ "metrics": "metrics",
+ "groupby": "groupby",
+ "columns": "groupby",
+ },
+ "datasource": f"{tbl_id}__table",
+ "viz_type": "dist_bar",
+ "time_range_endpoints": ["inclusive", "exclusive"],
+ "granularity_sqla": "ds",
+ "time_range": "No filter",
+ "metrics": ["count"],
+ "adhoc_filters": [],
+ "groupby": ["gender"],
+ "row_limit": 100,
+ }
+ }
+ )
+
+ class MockCache:
+ def get(self, key):
+ return form_data
+
+ def set(self, key, value, timeout=None):  # accept the standard cache.set signature
+ return None
+
+ mock_cache.return_value = MockCache()
+
+ rv = self.client.get("/superset/explore_json/data/valid-cache-key")
+ self.assertEqual(rv.status_code, 401)
+
+ def test_explore_json_data_invalid_cache_key(self):
+ self.login(username="admin")
+ cache_key = "invalid-cache-key"
+ rv = self.client.get(f"/superset/explore_json/data/{cache_key}")
+ data = json.loads(rv.data.decode("utf-8"))
+
+ self.assertEqual(rv.status_code, 404)
+ self.assertEqual(data["error"], "Cached data not found")
+
@mock.patch(
"superset.security.SupersetSecurityManager.get_schemas_accessible_by_user"
)
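The five `explore_json` tests added above repeat the same `dist_bar` form-data literal. If this suite grows, a small factory keeps the payload in one place; the helper below is a suggestion, not part of the diff:

```python
def dist_bar_form_data(tbl_id: int) -> dict:
    """Build the dist_bar form_data payload shared by the explore_json tests."""
    return {
        "queryFields": {
            "metrics": "metrics",
            "groupby": "groupby",
            "columns": "groupby",
        },
        "datasource": f"{tbl_id}__table",
        "viz_type": "dist_bar",
        "time_range_endpoints": ["inclusive", "exclusive"],
        "granularity_sqla": "ds",
        "time_range": "No filter",
        "metrics": ["count"],
        "adhoc_filters": [],
        "groupby": ["gender"],
        "row_limit": 100,
    }
```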
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index 3833ee0522..5bccf07507 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -82,7 +82,7 @@ class TestQueryContext(SupersetTestCase):
# construct baseline cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
- cache_key_original = query_context.cache_key(query_object)
+ cache_key_original = query_context.query_cache_key(query_object)
# make temporary change and revert it to refresh the changed_on property
datasource = ConnectorRegistry.get_datasource(
@@ -99,7 +99,7 @@ class TestQueryContext(SupersetTestCase):
# create new QueryContext with unchanged attributes and extract new cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
- cache_key_new = query_context.cache_key(query_object)
+ cache_key_new = query_context.query_cache_key(query_object)
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
@@ -115,20 +115,20 @@ class TestQueryContext(SupersetTestCase):
# construct baseline cache_key from query_context with post processing operation
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
- cache_key_original = query_context.cache_key(query_object)
+ cache_key_original = query_context.query_cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
- cache_key_with_null = query_context.cache_key(query_object)
+ cache_key_with_null = query_context.query_cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
- cache_key_without_post_processing = query_context.cache_key(query_object)
+ cache_key_without_post_processing = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
def test_query_context_time_range_endpoints(self):
@@ -179,13 +179,10 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
- data = responses[0]["data"]
+ data = responses["queries"][0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
- ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
- assert ck.datasource_uid == f"{table.id}__table"
-
def test_sql_injection_via_groupby(self):
"""
Ensure that calling invalid columns names in groupby are caught
@@ -197,7 +194,7 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["groupby"] = ["currentDatabase()"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
- assert query_payload[0].get("error") is not None
+ assert query_payload["queries"][0].get("error") is not None
def test_sql_injection_via_columns(self):
"""
@@ -212,7 +209,7 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["columns"] = ["*, 'extra'"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
- assert query_payload[0].get("error") is not None
+ assert query_payload["queries"][0].get("error") is not None
def test_sql_injection_via_metrics(self):
"""
@@ -233,7 +230,7 @@ class TestQueryContext(SupersetTestCase):
]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
- assert query_payload[0].get("error") is not None
+ assert query_payload["queries"][0].get("error") is not None
def test_samples_response_type(self):
"""
@@ -248,7 +245,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
- data = responses[0]["data"]
+ data = responses["queries"][0]["data"]
self.assertIsInstance(data, list)
self.assertEqual(len(data), 5)
self.assertNotIn("sum__num", data[0])
@@ -265,7 +262,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
- response = responses[0]
+ response = responses["queries"][0]
self.assertEqual(len(response), 2)
self.assertEqual(response["language"], "sql")
self.assertIn("SELECT", response["query"])
diff --git a/tests/superset_test_config.py b/tests/superset_test_config.py
index 88a79d378d..c23d2007ef 100644
--- a/tests/superset_test_config.py
+++ b/tests/superset_test_config.py
@@ -22,7 +22,7 @@ from tests.superset_test_custom_template_processors import CustomPrestoTemplateP
AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
-DEBUG = True
+DEBUG = False
SUPERSET_WEBSERVER_PORT = 8081
# Allowing SQLALCHEMY_DATABASE_URI and SQLALCHEMY_EXAMPLES_URI to be defined as an env vars for
@@ -96,6 +96,8 @@ DATA_CACHE_CONFIG = {
"CACHE_KEY_PREFIX": "superset_data_cache",
}
+GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me-test-secret-change-me"
+
class CeleryConfig(object):
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
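The test config gains `GLOBAL_ASYNC_QUERIES_JWT_SECRET`, a deliberately long throwaway value used to sign the async-query channel tokens. Assuming the feature signs and verifies those tokens with PyJWT over HS256 (an assumption; the setting name suggests JWT signing), the round trip looks like this, where the `"channel"` claim name is illustrative:

```python
import jwt  # PyJWT

secret = "test-secret-change-me-test-secret-change-me"

# Sign a per-session channel id into a compact token.
token = jwt.encode({"channel": "some-channel-uuid"}, secret, algorithm="HS256")

# Verify and decode on the way back; raises jwt.InvalidTokenError on tampering.
claims = jwt.decode(token, secret, algorithms=["HS256"])
assert claims["channel"] == "some-channel-uuid"
```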
diff --git a/tests/tasks/__init__.py b/tests/tasks/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/tasks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py
new file mode 100644
index 0000000000..6fe2e7c319
--- /dev/null
+++ b/tests/tasks/async_queries_tests.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for async query celery jobs in Superset"""
+import re
+from unittest import mock
+from uuid import uuid4
+
+import pytest
+
+from superset import db
+from superset.charts.commands.data import ChartDataCommand
+from superset.charts.commands.exceptions import ChartDataQueryFailedError
+from superset.connectors.sqla.models import SqlaTable
+from superset.exceptions import SupersetException
+from superset.extensions import async_query_manager
+from superset.tasks.async_queries import (
+ load_chart_data_into_cache,
+ load_explore_json_into_cache,
+)
+from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
+from tests.test_app import app
+
+
+def get_table_by_name(name: str) -> SqlaTable:
+ with app.app_context():
+ return db.session.query(SqlaTable).filter_by(table_name=name).one()
+
+
+class TestAsyncQueries(SupersetTestCase):
+ @mock.patch.object(async_query_manager, "update_job")
+ def test_load_chart_data_into_cache(self, mock_update_job):
+ async_query_manager.init_app(app)
+ table = get_table_by_name("birth_names")
+ form_data = get_query_context(table.name, table.id, table.type)
+ job_metadata = {
+ "channel_id": str(uuid4()),
+ "job_id": str(uuid4()),
+ "user_id": 1,
+ "status": "pending",
+ "errors": [],
+ }
+
+ load_chart_data_into_cache(job_metadata, form_data)
+
+ mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
+
+ @mock.patch.object(
+ ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
+ )
+ @mock.patch.object(async_query_manager, "update_job")
+ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
+ async_query_manager.init_app(app)
+ table = get_table_by_name("birth_names")
+ form_data = get_query_context(table.name, table.id, table.type)
+ job_metadata = {
+ "channel_id": str(uuid4()),
+ "job_id": str(uuid4()),
+ "user_id": 1,
+ "status": "pending",
+ "errors": [],
+ }
+ with pytest.raises(ChartDataQueryFailedError):
+ load_chart_data_into_cache(job_metadata, form_data)
+
+ mock_run_command.assert_called_with(cache=True)
+ errors = [{"message": "Error: foo"}]
+ mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
+
+ @mock.patch.object(async_query_manager, "update_job")
+ def test_load_explore_json_into_cache(self, mock_update_job):
+ async_query_manager.init_app(app)
+ table = get_table_by_name("birth_names")
+ form_data = {
+ "queryFields": {
+ "metrics": "metrics",
+ "groupby": "groupby",
+ "columns": "groupby",
+ },
+ "datasource": f"{table.id}__table",
+ "viz_type": "dist_bar",
+ "time_range_endpoints": ["inclusive", "exclusive"],
+ "granularity_sqla": "ds",
+ "time_range": "No filter",
+ "metrics": ["count"],
+ "adhoc_filters": [],
+ "groupby": ["gender"],
+ "row_limit": 100,
+ }
+ job_metadata = {
+ "channel_id": str(uuid4()),
+ "job_id": str(uuid4()),
+ "user_id": 1,
+ "status": "pending",
+ "errors": [],
+ }
+
+ load_explore_json_into_cache(job_metadata, form_data)
+
+ mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
+
+ @mock.patch.object(async_query_manager, "update_job")
+ def test_load_explore_json_into_cache_error(self, mock_update_job):
+ async_query_manager.init_app(app)
+ form_data = {}
+ job_metadata = {
+ "channel_id": str(uuid4()),
+ "job_id": str(uuid4()),
+ "user_id": 1,
+ "status": "pending",
+ "errors": [],
+ }
+
+ with pytest.raises(SupersetException):
+ load_explore_json_into_cache(job_metadata, form_data)
+
+ errors = ["The datasource associated with this chart no longer exists"]
+ mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 1dffdcd2ad..09fd3a7c91 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -163,9 +163,20 @@ class TestBaseViz(SupersetTestCase):
datasource.database.cache_timeout = 1666
self.assertEqual(1666, test_viz.cache_timeout)
+ datasource.database.cache_timeout = None
+ test_viz = viz.BaseViz(datasource, form_data={})
+ self.assertEqual(
+ app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"],
+ test_viz.cache_timeout,
+ )
+
+ data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
+ app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None
datasource.database.cache_timeout = None
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout)
+ # restore DATA_CACHE_CONFIG timeout
+ app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout
class TestTableViz(SupersetTestCase):
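The `viz_tests.py` additions pin down the cache-timeout fallback exercised by `BaseViz`: a database-level timeout wins; with it unset, `DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]` applies; with both unset, the global `CACHE_DEFAULT_TIMEOUT` is the last resort. A standalone sketch of that resolution order, covering only the three levels these hunks touch (the real property also consults chart- and datasource-level timeouts, which this diff doesn't change):

```python
def resolve_cache_timeout(database_timeout, data_cache_default, global_default):
    """Return the first non-None timeout, mirroring the order the test asserts."""
    for candidate in (database_timeout, data_cache_default, global_default):
        if candidate is not None:
            return candidate
    return None


assert resolve_cache_timeout(1666, 300, 60) == 1666
assert resolve_cache_timeout(None, 300, 60) == 300
assert resolve_cache_timeout(None, None, 60) == 60
```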