feat(SIP-39): Async query support for charts (#11499)

* Generate JWT in Flask app

* Refactor chart data API query logic, add JWT validation and async worker

* Add redis stream implementation, refactoring

* Add chart data cache endpoint, refactor QueryContext caching

* Typing, linting, refactoring

* pytest fixes and openapi schema update

* Enforce caching be configured for async query init

* Async query processing for explore_json endpoint

* Add /api/v1/async_event endpoint

* Async frontend for dashboards [WIP]

* Chart async error message support, refactoring

* Abstract asyncEvent middleware

* Async chart loading for Explore

* Pylint fixes

* asyncEvent middleware -> TypeScript, JS linting

* Chart data API: enforce forced_cache, add tests

* Add tests for explore_json endpoints

* Add test for chart data cache endpoint (no login)

* Consolidate set_and_log_cache and add STORE_CACHE_KEYS_IN_METADATA_DB flag

* Add tests for tasks/async_queries and address PR comments

* Bypass non-JSON result formats for async queries

* Add tests for redux middleware

* Remove debug statement

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>

* Skip force_cached if no queryObj

* SunburstViz: don't modify self.form_data

* Fix failing annotation test

* Resolve merge/lint issues

* Reduce polling delay

* Fix new getClientErrorObject reference

* Fix flakey unit tests

* /api/v1/async_event: increment redis stream ID, add tests

* PR feedback: refactoring, configuration

* Fixup: remove debugging

* Fix typescript errors due to redux upgrade

* Update UPDATING.md

* Fix failing py tests

* asyncEvent_spec.js -> asyncEvent_spec.ts

* Refactor flakey Python 3.7 mock assertions

* Fix another shared state issue in Py tests

* Use 'sub' claim in JWT for user_id

* Refactor async middleware config

* Fixup: restore FeatureFlag boolean type

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Rob DiCiuccio 2020-12-10 20:21:56 -08:00 committed by GitHub
parent 0fdf026cbc
commit 4d329071a1
64 changed files with 2219 additions and 197 deletions

View File

@ -623,6 +623,11 @@ cd superset-frontend
npm run test
```
To run a single test file:
```bash
npm run test -- path/to/file.js
```
### Integration Testing
We use [Cypress](https://www.cypress.io/) for integration tests. Tests can be run with `tox -e cypress`. To open Cypress and explore tests, first set up and run the test server:

View File

@ -23,7 +23,7 @@ This file documents any backwards-incompatible changes in Superset and
assists people when migrating to a new version.
## Next
- [11499](https://github.com/apache/incubator-superset/pull/11499): Breaking change: `STORE_CACHE_KEYS_IN_METADATA_DB` config flag added (default=`False`) to write `CacheKey` records to the metadata DB. `CacheKey` recording was enabled by default previously.
- [11920](https://github.com/apache/incubator-superset/pull/11920): Undoes the DB migration from [11714](https://github.com/apache/incubator-superset/pull/11714) to prevent adding new columns to the logs table. Deploying a sha between these two PRs may result in locking your DB.
- [11704](https://github.com/apache/incubator-superset/pull/11704) Breaking change: Jinja templating for SQL queries has been updated, removing default modules such as `datetime` and `random` and enforcing static template values. To restore or extend functionality, use `JINJA_CONTEXT_ADDONS` and `CUSTOM_TEMPLATE_PROCESSORS` in `superset_config.py`.
- [11714](https://github.com/apache/incubator-superset/pull/11714): Logs

View File

@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,redis,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

View File

@ -0,0 +1,265 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import fetchMock from 'fetch-mock';
import sinon from 'sinon';
import * as featureFlags from 'src/featureFlags';
import initAsyncEvents from 'src/middleware/asyncEvent';
jest.useFakeTimers();
describe('asyncEvent middleware', () => {
const next = sinon.spy();
const state = {
charts: {
123: {
id: 123,
status: 'loading',
asyncJobId: 'foo123',
},
345: {
id: 345,
status: 'loading',
asyncJobId: 'foo345',
},
},
};
const events = [
{
status: 'done',
result_url: '/api/v1/chart/data/cache-key-1',
job_id: 'foo123',
channel_id: '999',
errors: [],
},
{
status: 'done',
result_url: '/api/v1/chart/data/cache-key-2',
job_id: 'foo345',
channel_id: '999',
errors: [],
},
];
const mockStore = {
getState: () => state,
dispatch: sinon.stub(),
};
const action = {
type: 'GENERIC_ACTION',
};
const EVENTS_ENDPOINT = 'glob:*/api/v1/async_event/*';
const CACHED_DATA_ENDPOINT = 'glob:*/api/v1/chart/data/*';
const config = {
GLOBAL_ASYNC_QUERIES_TRANSPORT: 'polling',
GLOBAL_ASYNC_QUERIES_POLLING_DELAY: 500,
};
let featureEnabledStub: any;
function setup() {
const getPendingComponents = sinon.stub();
const successAction = sinon.spy();
const errorAction = sinon.spy();
const testCallback = sinon.stub();
const testCallbackPromise = sinon.stub();
testCallbackPromise.returns(
new Promise(resolve => {
testCallback.callsFake(resolve);
}),
);
return {
getPendingComponents,
successAction,
errorAction,
testCallback,
testCallbackPromise,
};
}
beforeEach(() => {
fetchMock.get(EVENTS_ENDPOINT, {
status: 200,
body: { result: [] },
});
fetchMock.get(CACHED_DATA_ENDPOINT, {
status: 200,
body: { result: { some: 'data' } },
});
featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled');
featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(true);
});
afterEach(() => {
fetchMock.reset();
next.resetHistory();
featureEnabledStub.restore();
});
afterAll(fetchMock.reset);
it('should initialize and call next', () => {
const { getPendingComponents, successAction, errorAction } = setup();
getPendingComponents.returns([]);
const asyncEventMiddleware = initAsyncEvents({
config,
getPendingComponents,
successAction,
errorAction,
});
asyncEventMiddleware(mockStore)(next)(action);
expect(next.callCount).toBe(1);
});
it('should fetch events when there are pending components', () => {
const {
getPendingComponents,
successAction,
errorAction,
testCallback,
testCallbackPromise,
} = setup();
getPendingComponents.returns(Object.values(state.charts));
const asyncEventMiddleware = initAsyncEvents({
config,
getPendingComponents,
successAction,
errorAction,
processEventsCallback: testCallback,
});
asyncEventMiddleware(mockStore)(next)(action);
return testCallbackPromise().then(() => {
expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
});
});
it('should fetch cached when there are successful events', () => {
const {
getPendingComponents,
successAction,
errorAction,
testCallback,
testCallbackPromise,
} = setup();
fetchMock.reset();
fetchMock.get(EVENTS_ENDPOINT, {
status: 200,
body: { result: events },
});
fetchMock.get(CACHED_DATA_ENDPOINT, {
status: 200,
body: { result: { some: 'data' } },
});
getPendingComponents.returns(Object.values(state.charts));
const asyncEventMiddleware = initAsyncEvents({
config,
getPendingComponents,
successAction,
errorAction,
processEventsCallback: testCallback,
});
asyncEventMiddleware(mockStore)(next)(action);
return testCallbackPromise().then(() => {
expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2);
expect(successAction.callCount).toBe(2);
});
});
it('should call errorAction for cache fetch error responses', () => {
const {
getPendingComponents,
successAction,
errorAction,
testCallback,
testCallbackPromise,
} = setup();
fetchMock.reset();
fetchMock.get(EVENTS_ENDPOINT, {
status: 200,
body: { result: events },
});
fetchMock.get(CACHED_DATA_ENDPOINT, {
status: 400,
body: { errors: ['error'] },
});
getPendingComponents.returns(Object.values(state.charts));
const asyncEventMiddleware = initAsyncEvents({
config,
getPendingComponents,
successAction,
errorAction,
processEventsCallback: testCallback,
});
asyncEventMiddleware(mockStore)(next)(action);
return testCallbackPromise().then(() => {
expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2);
expect(errorAction.callCount).toBe(2);
});
});
it('should handle event fetching error responses', () => {
const {
getPendingComponents,
successAction,
errorAction,
testCallback,
testCallbackPromise,
} = setup();
fetchMock.reset();
fetchMock.get(EVENTS_ENDPOINT, {
status: 400,
body: { message: 'error' },
});
getPendingComponents.returns(Object.values(state.charts));
const asyncEventMiddleware = initAsyncEvents({
config,
getPendingComponents,
successAction,
errorAction,
processEventsCallback: testCallback,
});
asyncEventMiddleware(mockStore)(next)(action);
return testCallbackPromise().then(() => {
expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
});
});
it('should not fetch events when async queries are disabled', () => {
featureEnabledStub.restore();
featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled');
featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(false);
const { getPendingComponents, successAction, errorAction } = setup();
getPendingComponents.returns(Object.values(state.charts));
const asyncEventMiddleware = initAsyncEvents({
config,
getPendingComponents,
successAction,
errorAction,
});
asyncEventMiddleware(mockStore)(next)(action);
expect(getPendingComponents.called).toBe(false);
});
});

View File

@ -17,7 +17,7 @@
* under the License.
*/
import { ErrorTypeEnum } from 'src/components/ErrorMessage/types';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
describe('getClientErrorObject()', () => {
it('Returns a Promise', () => {

View File

@ -30,7 +30,7 @@ import {
addSuccessToast as addSuccessToastAction,
addWarningToast as addWarningToastAction,
} from '../../messageToasts/actions/index';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
import COMMON_ERR_MESSAGES from '../../utils/errorMessages';
export const RESET_STATE = 'RESET_STATE';

View File

@ -25,7 +25,7 @@ import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
import Button from 'src/components/Button';
import CopyToClipboard from '../../components/CopyToClipboard';
import { storeQuery } from '../../utils/common';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
import withToasts from '../../messageToasts/enhancers/withToasts';
const propTypes = {

View File

@ -38,7 +38,7 @@ import {
import { addDangerToast } from '../messageToasts/actions';
import { logEvent } from '../logger/actions';
import { Logger, LOG_ACTIONS_LOAD_CHART } from '../logger/LogUtils';
import getClientErrorObject from '../utils/getClientErrorObject';
import { getClientErrorObject } from '../utils/getClientErrorObject';
import { allowCrossDomain as domainShardingEnabled } from '../utils/hostNamesConfig';
export const CHART_UPDATE_STARTED = 'CHART_UPDATE_STARTED';
@ -66,6 +66,11 @@ export function chartUpdateFailed(queryResponse, key) {
return { type: CHART_UPDATE_FAILED, queryResponse, key };
}
export const CHART_UPDATE_QUEUED = 'CHART_UPDATE_QUEUED';
export function chartUpdateQueued(asyncJobMeta, key) {
return { type: CHART_UPDATE_QUEUED, asyncJobMeta, key };
}
export const CHART_RENDERING_FAILED = 'CHART_RENDERING_FAILED';
export function chartRenderingFailed(error, key, stackTrace) {
return { type: CHART_RENDERING_FAILED, error, key, stackTrace };
@ -356,6 +361,12 @@ export function exploreJSON(
const chartDataRequestCaught = chartDataRequest
.then(response => {
if (isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES)) {
// deal with getChartDataRequest transforming the response data
const result = 'result' in response ? response.result[0] : response;
return dispatch(chartUpdateQueued(result, key));
}
// new API returns an object with an array of results
// problem: response holds a list of results, when before we were just getting one result.
// How to make the entire app compatible with multiple results?

View File

@ -71,6 +71,14 @@ export default function chartReducer(charts = {}, action) {
chartUpdateEndTime: now(),
};
},
[actions.CHART_UPDATE_QUEUED](state) {
return {
...state,
asyncJobId: action.asyncJobMeta.job_id,
chartStatus: 'loading',
chartUpdateEndTime: now(),
};
},
[actions.CHART_RENDERING_SUCCEEDED](state) {
return { ...state, chartStatus: 'rendered', chartUpdateEndTime: now() };
},

View File

@ -21,7 +21,7 @@ import PropTypes from 'prop-types';
// TODO: refactor this with `import { AsyncSelect } from src/components/Select`
import { Select } from 'src/components/Select';
import { t, SupersetClient } from '@superset-ui/core';
import getClientErrorObject from '../utils/getClientErrorObject';
import { getClientErrorObject } from '../utils/getClientErrorObject';
const propTypes = {
dataEndpoint: PropTypes.string.isRequired,

View File

@ -29,7 +29,7 @@ import {
updateDirectPathToFilter,
} from './dashboardFilters';
import { applyDefaultFormData } from '../../explore/store';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
import { SAVE_TYPE_OVERWRITE } from '../util/constants';
import {
addSuccessToast,

View File

@ -17,7 +17,7 @@
* under the License.
*/
import { SupersetClient } from '@superset-ui/core';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
export const SET_DATASOURCE = 'SET_DATASOURCE';
export function setDatasource(datasource, key) {

View File

@ -22,7 +22,7 @@ import rison from 'rison';
import { addDangerToast } from 'src/messageToasts/actions';
import { getDatasourceParameter } from 'src/modules/utils';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
export const SET_ALL_SLICES = 'SET_ALL_SLICES';
export function setAllSlices(slices) {

View File

@ -35,7 +35,7 @@ import FormLabel from 'src/components/FormLabel';
import { JsonEditor } from 'src/components/AsyncAceEditor';
import ColorSchemeControlWrapper from 'src/dashboard/components/ColorSchemeControlWrapper';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
import withToasts from '../../messageToasts/enhancers/withToasts';
import '../stylesheets/buttons.less';

View File

@ -24,7 +24,9 @@ import { initFeatureFlags } from 'src/featureFlags';
import { initEnhancer } from '../reduxUtils';
import getInitialState from './reducers/getInitialState';
import rootReducer from './reducers/index';
import initAsyncEvents from '../middleware/asyncEvent';
import logger from '../middleware/loggerMiddleware';
import * as actions from '../chart/chartAction';
import App from './App';
@ -33,10 +35,23 @@ const bootstrapData = JSON.parse(appContainer.getAttribute('data-bootstrap'));
initFeatureFlags(bootstrapData.common.feature_flags);
const initState = getInitialState(bootstrapData);
const asyncEventMiddleware = initAsyncEvents({
config: bootstrapData.common.conf,
getPendingComponents: ({ charts }) =>
Object.values(charts).filter(c => c.chartStatus === 'loading'),
successAction: (componentId, componentData) =>
actions.chartUpdateSucceeded(componentData, componentId),
errorAction: (componentId, response) =>
actions.chartUpdateFailed(response, componentId),
});
const store = createStore(
rootReducer,
initState,
compose(applyMiddleware(thunk, logger), initEnhancer(false)),
compose(
applyMiddleware(thunk, logger, asyncEventMiddleware),
initEnhancer(false),
),
);
ReactDOM.render(<App store={store} />, document.getElementById('app'));

View File

@ -27,7 +27,7 @@ import { Alert, FormControl, FormControlProps } from 'react-bootstrap';
import { SupersetClient, t } from '@superset-ui/core';
import TableView from 'src/components/TableView';
import Modal from 'src/common/components/Modal';
import getClientErrorObject from '../utils/getClientErrorObject';
import { getClientErrorObject } from '../utils/getClientErrorObject';
import Loading from '../components/Loading';
import withToasts from '../messageToasts/enhancers/withToasts';

View File

@ -33,7 +33,7 @@ import Loading from 'src/components/Loading';
import TableSelector from 'src/components/TableSelector';
import EditableTitle from 'src/components/EditableTitle';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import CheckboxControl from 'src/explore/components/controls/CheckboxControl';
import TextControl from 'src/explore/components/controls/TextControl';

View File

@ -25,7 +25,7 @@ import Modal from 'src/common/components/Modal';
import AsyncEsmComponent from 'src/components/AsyncEsmComponent';
import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import withToasts from 'src/messageToasts/enhancers/withToasts';
const DatasourceEditor = AsyncEsmComponent(() => import('./DatasourceEditor'));

View File

@ -23,7 +23,7 @@ import Tabs from 'src/common/components/Tabs';
import Loading from 'src/components/Loading';
import TableView, { EmptyWrapperType } from 'src/components/TableView';
import { getChartDataRequest } from 'src/chart/chartAction';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import {
CopyToClipboardButton,
FilterInput,

View File

@ -30,7 +30,7 @@ import { DropdownButton } from 'react-bootstrap';
import { styled, t } from '@superset-ui/core';
import { Menu } from 'src/common/components';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
import CopyToClipboard from '../../components/CopyToClipboard';
import { getChartDataRequest } from '../../chart/chartAction';
import downloadAsImage from '../../utils/downloadAsImage';

View File

@ -32,7 +32,7 @@ import rison from 'rison';
import { t, SupersetClient } from '@superset-ui/core';
import Chart, { Slice } from 'src/types/Chart';
import FormLabel from 'src/components/FormLabel';
import getClientErrorObject from '../../utils/getClientErrorObject';
import { getClientErrorObject } from '../../utils/getClientErrorObject';
type PropertiesModalProps = {
slice: Slice;

View File

@ -25,6 +25,8 @@ import { initFeatureFlags } from '../featureFlags';
import { initEnhancer } from '../reduxUtils';
import getInitialState from './reducers/getInitialState';
import rootReducer from './reducers/index';
import initAsyncEvents from '../middleware/asyncEvent';
import * as actions from '../chart/chartAction';
import App from './App';
@ -35,10 +37,23 @@ const bootstrapData = JSON.parse(
initFeatureFlags(bootstrapData.common.feature_flags);
const initState = getInitialState(bootstrapData);
const asyncEventMiddleware = initAsyncEvents({
config: bootstrapData.common.conf,
getPendingComponents: ({ charts }) =>
Object.values(charts).filter(c => c.chartStatus === 'loading'),
successAction: (componentId, componentData) =>
actions.chartUpdateSucceeded(componentData, componentId),
errorAction: (componentId, response) =>
actions.chartUpdateFailed(response, componentId),
});
const store = createStore(
rootReducer,
initState,
compose(applyMiddleware(thunk, logger), initEnhancer(false)),
compose(
applyMiddleware(thunk, logger, asyncEventMiddleware),
initEnhancer(false),
),
);
ReactDOM.render(<App store={store} />, document.getElementById('app'));

View File

@ -34,6 +34,7 @@ export enum FeatureFlag {
DISPLAY_MARKDOWN_HTML = 'DISPLAY_MARKDOWN_HTML',
ESCAPE_MARKDOWN_HTML = 'ESCAPE_MARKDOWN_HTML',
VERSIONED_EXPORT = 'VERSIONED_EXPORT',
GLOBAL_ASYNC_QUERIES = 'GLOBAL_ASYNC_QUERIES',
}
export type FeatureFlagMap = {

View File

@ -0,0 +1,196 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { Middleware, MiddlewareAPI, Dispatch } from 'redux';
import { makeApi, SupersetClient } from '@superset-ui/core';
import { SupersetError } from 'src/components/ErrorMessage/types';
import { isFeatureEnabled, FeatureFlag } from '../featureFlags';
import {
getClientErrorObject,
parseErrorJson,
} from '../utils/getClientErrorObject';
export type AsyncEvent = {
id: string;
channel_id: string;
job_id: string;
user_id: string;
status: string;
errors: SupersetError[];
result_url: string;
};
type AsyncEventOptions = {
config: {
GLOBAL_ASYNC_QUERIES_TRANSPORT: string;
GLOBAL_ASYNC_QUERIES_POLLING_DELAY: number;
};
getPendingComponents: (state: any) => any[];
successAction: (componentId: number, componentData: any) => { type: string };
errorAction: (componentId: number, response: any) => { type: string };
processEventsCallback?: (events: AsyncEvent[]) => void; // this is currently used only for tests
};
type CachedDataResponse = {
componentId: number;
status: string;
data: any;
};
const initAsyncEvents = (options: AsyncEventOptions) => {
// TODO: implement websocket support
const TRANSPORT_POLLING = 'polling';
const {
config,
getPendingComponents,
successAction,
errorAction,
processEventsCallback,
} = options;
const transport = config.GLOBAL_ASYNC_QUERIES_TRANSPORT || TRANSPORT_POLLING;
const polling_delay = config.GLOBAL_ASYNC_QUERIES_POLLING_DELAY || 500;
const middleware: Middleware = (store: MiddlewareAPI) => (next: Dispatch) => {
const JOB_STATUS = {
PENDING: 'pending',
RUNNING: 'running',
ERROR: 'error',
DONE: 'done',
};
const LOCALSTORAGE_KEY = 'last_async_event_id';
const POLLING_URL = '/api/v1/async_event/';
let lastReceivedEventId: string | null;
try {
lastReceivedEventId = localStorage.getItem(LOCALSTORAGE_KEY);
} catch (err) {
console.warn('failed to fetch last event Id from localStorage');
}
const fetchEvents = makeApi<
{ last_id?: string | null },
{ result: AsyncEvent[] }
>({
method: 'GET',
endpoint: POLLING_URL,
});
const fetchCachedData = async (
asyncEvent: AsyncEvent,
componentId: number,
): Promise<CachedDataResponse> => {
let status = 'success';
let data;
try {
const { json } = await SupersetClient.get({
endpoint: asyncEvent.result_url,
});
data = 'result' in json ? json.result[0] : json;
} catch (response) {
status = 'error';
data = await getClientErrorObject(response);
}
return { componentId, status, data };
};
const setLastId = (asyncEvent: AsyncEvent) => {
lastReceivedEventId = asyncEvent.id;
try {
localStorage.setItem(LOCALSTORAGE_KEY, lastReceivedEventId as string);
} catch (err) {
console.warn('Error saving event ID to localStorage', err);
}
};
const processEvents = async () => {
const state = store.getState();
const queuedComponents = getPendingComponents(state);
const eventArgs = lastReceivedEventId
? { last_id: lastReceivedEventId }
: {};
const events: AsyncEvent[] = [];
if (queuedComponents && queuedComponents.length) {
try {
const { result: events } = await fetchEvents(eventArgs);
if (events && events.length) {
const componentsByJobId = queuedComponents.reduce((acc, item) => {
acc[item.asyncJobId] = item;
return acc;
}, {});
const fetchDataEvents: Promise<CachedDataResponse>[] = [];
events.forEach((asyncEvent: AsyncEvent) => {
const component = componentsByJobId[asyncEvent.job_id];
if (!component) {
console.warn(
'component not found for job_id',
asyncEvent.job_id,
);
return setLastId(asyncEvent);
}
const componentId = component.id;
switch (asyncEvent.status) {
case JOB_STATUS.DONE:
fetchDataEvents.push(
fetchCachedData(asyncEvent, componentId),
);
break;
case JOB_STATUS.ERROR:
store.dispatch(
errorAction(componentId, parseErrorJson(asyncEvent)),
);
break;
default:
console.warn('received event with status', asyncEvent.status);
}
return setLastId(asyncEvent);
});
const fetchResults = await Promise.all(fetchDataEvents);
fetchResults.forEach(result => {
if (result.status === 'success') {
store.dispatch(successAction(result.componentId, result.data));
} else {
store.dispatch(errorAction(result.componentId, result.data));
}
});
}
} catch (err) {
console.warn(err);
}
}
if (processEventsCallback) processEventsCallback(events);
return setTimeout(processEvents, polling_delay);
};
if (
isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES) &&
transport === TRANSPORT_POLLING
)
processEvents();
return action => next(action);
};
return middleware;
};
export default initAsyncEvents;

View File

@ -19,7 +19,8 @@
/* eslint global-require: 0 */
import $ from 'jquery';
import { SupersetClient } from '@superset-ui/core';
import getClientErrorObject, {
import {
getClientErrorObject,
ClientErrorObject,
} from '../utils/getClientErrorObject';
import setupErrorMessages from './setupErrorMessages';

View File

@ -21,7 +21,7 @@ import {
getTimeFormatter,
TimeFormats,
} from '@superset-ui/core';
import getClientErrorObject from './getClientErrorObject';
import { getClientErrorObject } from './getClientErrorObject';
// ATTENTION: If you change any constants, make sure to also change constants.py

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { SupersetClientResponse, t } from '@superset-ui/core';
import { JsonObject, SupersetClientResponse, t } from '@superset-ui/core';
import {
SupersetError,
ErrorTypeEnum,
@ -36,7 +36,33 @@ export type ClientErrorObject = {
stacktrace?: string;
} & Partial<SupersetClientResponse>;
export default function getClientErrorObject(
export function parseErrorJson(responseObject: JsonObject): ClientErrorObject {
let error = { ...responseObject };
// Backwards compatibility for old error renderers with the new error object
if (error.errors && error.errors.length > 0) {
error.error = error.description = error.errors[0].message;
error.link = error.errors[0]?.extra?.link;
}
if (error.stack) {
error = {
...error,
error:
t('Unexpected error: ') +
(error.description || t('(no description, click to see stack trace)')),
stacktrace: error.stack,
};
} else if (error.responseText && error.responseText.indexOf('CSRF') >= 0) {
error = {
...error,
error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT),
};
}
return { ...error, error: error.error }; // explicit ClientErrorObject
}
export function getClientErrorObject(
response: SupersetClientResponse | (Response & { timeout: number }) | string,
): Promise<ClientErrorObject> {
// takes a SupersetClientResponse as input, attempts to read response as Json if possible,
@ -54,33 +80,8 @@ export default function getClientErrorObject(
.clone()
.json()
.then(errorJson => {
let error = { ...responseObject, ...errorJson };
// Backwards compatibility for old error renderers with the new error object
if (error.errors && error.errors.length > 0) {
error.error = error.description = error.errors[0].message;
error.link = error.errors[0]?.extra?.link;
}
if (error.stack) {
error = {
...error,
error:
t('Unexpected error: ') +
(error.description ||
t('(no description, click to see stack trace)')),
stacktrace: error.stack,
};
} else if (
error.responseText &&
error.responseText.indexOf('CSRF') >= 0
) {
error = {
...error,
error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT),
};
}
resolve(error);
const error = { ...responseObject, ...errorJson };
resolve(parseErrorJson(error));
})
.catch(() => {
// fall back to reading as text

View File

@ -29,7 +29,7 @@ import ConfirmStatusChange from 'src/components/ConfirmStatusChange';
import DeleteModal from 'src/components/DeleteModal';
import ListView, { ListViewProps } from 'src/components/ListView';
import SubMenu, { SubMenuProps } from 'src/components/Menu/SubMenu';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import withToasts from 'src/messageToasts/enhancers/withToasts';
import { IconName } from 'src/components/Icon';
import { useListViewResource } from 'src/views/CRUD/hooks';

View File

@ -21,7 +21,7 @@ import { styled, t, SupersetClient } from '@superset-ui/core';
import InfoTooltip from 'src/common/components/InfoTooltip';
import { useSingleViewResource } from 'src/views/CRUD/hooks';
import withToasts from 'src/messageToasts/enhancers/withToasts';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import Icon from 'src/components/Icon';
import Modal from 'src/common/components/Modal';
import Tabs from 'src/common/components/Tabs';

View File

@ -25,7 +25,7 @@ import { FetchDataConfig } from 'src/components/ListView';
import { FilterValue } from 'src/components/ListView/types';
import Chart, { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import { FavoriteStatus } from './types';
interface ListViewResourceState<D extends object = any> {

View File

@ -25,7 +25,7 @@ import {
} from '@superset-ui/core';
import Chart from 'src/types/Chart';
import rison from 'rison';
import getClientErrorObject from 'src/utils/getClientErrorObject';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import { FetchDataConfig } from 'src/components/ListView';
import { Dashboard } from './types';

View File

@ -30,6 +30,7 @@ from superset.extensions import (
_event_logger,
APP_DIR,
appbuilder,
async_query_manager,
cache_manager,
celery_app,
csrf,
@ -127,6 +128,7 @@ class SupersetAppInitializer:
# pylint: disable=too-many-branches
from superset.annotation_layers.api import AnnotationLayerRestApi
from superset.annotation_layers.annotations.api import AnnotationRestApi
from superset.async_events.api import AsyncEventsRestApi
from superset.cachekeys.api import CacheRestApi
from superset.charts.api import ChartRestApi
from superset.connectors.druid.views import (
@ -201,6 +203,7 @@ class SupersetAppInitializer:
#
appbuilder.add_api(AnnotationRestApi)
appbuilder.add_api(AnnotationLayerRestApi)
appbuilder.add_api(AsyncEventsRestApi)
appbuilder.add_api(CacheRestApi)
appbuilder.add_api(ChartRestApi)
appbuilder.add_api(CssTemplateRestApi)
@ -498,6 +501,7 @@ class SupersetAppInitializer:
self.configure_url_map_converters()
self.configure_data_sources()
self.configure_auth_provider()
self.configure_async_queries()
# Hook that provides administrators a handle on the Flask APP
# after initialization
@ -648,6 +652,10 @@ class SupersetAppInitializer:
for ex in csrf_exempt_list:
csrf.exempt(ex)
def configure_async_queries(self) -> None:
if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"):
async_query_manager.init_app(self.flask_app)
def register_blueprints(self) -> None:
for bp in self.config["BLUEPRINTS"]:
try:

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from flask import request, Response
from flask_appbuilder import expose
from flask_appbuilder.api import BaseApi, safe
from flask_appbuilder.security.decorators import permission_name, protect
from superset.extensions import async_query_manager, event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
logger = logging.getLogger(__name__)
class AsyncEventsRestApi(BaseApi):
resource_name = "async_event"
allow_browser_login = True
include_route_methods = {
"events",
}
@expose("/", methods=["GET"])
@event_logger.log_this
@protect()
@safe
@permission_name("list")
def events(self) -> Response:
"""
Reads off of the Redis async events stream, using the user's JWT token and
optional query params for last event received.
---
get:
description: >-
Reads off of the Redis events stream, using the user's JWT token and
optional query params for last event received.
parameters:
- in: query
name: last_id
description: Last ID received by the client
schema:
type: string
responses:
200:
description: Async event results
content:
application/json:
schema:
type: object
properties:
result:
type: array
items:
type: object
properties:
id:
type: string
channel_id:
type: string
job_id:
type: string
user_id:
type: integer
status:
type: string
msg:
type: string
cache_key:
type: string
401:
$ref: '#/components/responses/401'
500:
$ref: '#/components/responses/500'
"""
try:
async_channel_id = async_query_manager.parse_jwt_from_request(request)[
"channel"
]
last_event_id = request.args.get("last_id")
events = async_query_manager.read_events(async_channel_id, last_event_id)
except AsyncQueryTokenException:
return self.response_401()
return self.response(200, result=events)
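
For reference, the polling loop that the frontend middleware runs against this endpoint can also be exercised directly. A minimal sketch (illustrative only, not part of this commit), assuming the `requests` library and an authenticated session that already carries the async JWT cookie (`async-token` by default) against a hypothetical host:

```python
import time

import requests

BASE_URL = "https://superset.example.com"  # hypothetical host
session = requests.Session()
# ... log in first so the session carries the auth and "async-token" cookies ...

last_id = None
while True:
    params = {"last_id": last_id} if last_id else {}
    resp = session.get(f"{BASE_URL}/api/v1/async_event/", params=params)
    resp.raise_for_status()
    for event in resp.json()["result"]:
        last_id = event["id"]  # remember the offset, as the middleware does via localStorage
        if event["status"] == "done":
            print("cached result available at", event["result_url"])
        elif event["status"] == "error":
            print("job", event["job_id"], "failed:", event.get("errors"))
    time.sleep(0.5)  # GLOBAL_ASYNC_QUERIES_POLLING_DELAY defaults to 500 ms
```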

View File

@ -33,10 +33,13 @@ from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset.charts.commands.bulk_delete import BulkDeleteChartCommand
from superset.charts.commands.create import CreateChartCommand
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.delete import DeleteChartCommand
from superset.charts.commands.exceptions import (
ChartBulkDeleteFailedError,
ChartCreateFailedError,
ChartDataCacheLoadError,
ChartDataQueryFailedError,
ChartDeleteFailedError,
ChartForbiddenError,
ChartInvalidError,
@ -50,7 +53,6 @@ from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
from superset.charts.schemas import (
CHART_SCHEMAS,
ChartDataQueryContextSchema,
ChartPostSchema,
ChartPutSchema,
get_delete_ids_schema,
@ -67,7 +69,12 @@ from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils.core import ChartDataResultFormat, json_int_dttm_ser
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import (
ChartDataResultFormat,
ChartDataResultType,
json_int_dttm_ser,
)
from superset.utils.screenshots import ChartScreenshot
from superset.utils.urls import get_url_path
from superset.views.base_api import (
@ -93,6 +100,8 @@ class ChartRestApi(BaseSupersetModelRestApi):
RouteMethod.RELATED,
"bulk_delete", # not using RouteMethod since locally defined
"data",
"data_from_cache",
"viz_types",
"favorite_status",
}
class_permission_name = "SliceModelView"
@ -448,6 +457,39 @@ class ChartRestApi(BaseSupersetModelRestApi):
except ChartBulkDeleteFailedError as ex:
return self.response_422(message=str(ex))
def get_data_response(
self, command: ChartDataCommand, force_cached: bool = False
) -> Response:
try:
result = command.run(force_cached=force_cached)
except ChartDataCacheLoadError as exc:
return self.response_422(message=exc.message)
except ChartDataQueryFailedError as exc:
return self.response_400(message=exc.message)
result_format = result["query_context"].result_format
if result_format == ChartDataResultFormat.CSV:
# return the first result
data = result["queries"][0]["data"]
return CsvResponse(
data,
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv",
)
if result_format == ChartDataResultFormat.JSON:
response_data = simplejson.dumps(
{"result": result["queries"]},
default=json_int_dttm_ser,
ignore_nan=True,
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
return resp
return self.response_400(message=f"Unsupported result_format: {result_format}")
@expose("/data", methods=["POST"])
@protect()
@safe
@ -478,8 +520,16 @@ class ChartRestApi(BaseSupersetModelRestApi):
application/json:
schema:
$ref: "#/components/schemas/ChartDataResponseSchema"
202:
description: Async job details
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataAsyncResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
500:
$ref: '#/components/responses/500'
"""
@ -490,47 +540,88 @@ class ChartRestApi(BaseSupersetModelRestApi):
json_body = json.loads(request.form["form_data"])
else:
return self.response_400(message="Request is not JSON")
try:
query_context = ChartDataQueryContextSchema().load(json_body)
except KeyError:
return self.response_400(message="Request is incorrect")
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
command.validate()
except ValidationError as error:
return self.response_400(
message=_("Request is incorrect: %(error)s", error=error.messages)
)
try:
query_context.raise_for_access()
except SupersetSecurityException:
return self.response_401()
payload = query_context.get_payload()
for query in payload:
if query.get("error"):
return self.response_400(message=f"Error: {query['error']}")
result_format = query_context.result_format
response = self.response_400(
message=f"Unsupported result_format: {result_format}"
)
# TODO: support CSV, SQL query and other non-JSON types
if (
is_feature_enabled("GLOBAL_ASYNC_QUERIES")
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
if result_format == ChartDataResultFormat.CSV:
# return the first result
result = payload[0]["data"]
response = CsvResponse(
result,
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv",
try:
command.validate_async_request(request)
except AsyncQueryTokenException:
return self.response_401()
result = command.run_async()
return self.response(202, **result)
return self.get_data_response(command)
@expose("/data/<cache_key>", methods=["GET"])
@event_logger.log_this
@protect()
@safe
@statsd_metrics
def data_from_cache(self, cache_key: str) -> Response:
"""
Takes a query context cache key and returns payload
data response for the given query.
---
get:
description: >-
Takes a query context cache key and returns payload data
response for the given query.
parameters:
- in: path
schema:
type: string
name: cache_key
responses:
200:
description: Query result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
command = ChartDataCommand()
try:
cached_data = command.load_query_context_from_cache(cache_key)
command.set_query_context(cached_data)
command.validate()
except ChartDataCacheLoadError:
return self.response_404()
except ValidationError as error:
return self.response_400(
message=_("Request is incorrect: %(error)s", error=error.messages)
)
except SupersetSecurityException as exc:
logger.info(exc)
return self.response_401()
if result_format == ChartDataResultFormat.JSON:
response_data = simplejson.dumps(
{"result": payload}, default=json_int_dttm_ser, ignore_nan=True
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
response = resp
return response
return self.get_data_response(command, True)
@expose("/<pk>/cache_screenshot/", methods=["GET"])
@protect()
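
The 202 path above pairs with the new `data_from_cache` endpoint: the client submits the query, receives job metadata, and later fetches the cached payload once the async event stream reports completion. A rough end-to-end sketch under the same assumptions as the polling example (authenticated session, hypothetical host, placeholder payload):

```python
import requests

BASE_URL = "https://superset.example.com"  # hypothetical host
session = requests.Session()  # assumed to carry auth and async JWT cookies

query_context = {"datasource": {"id": 1, "type": "table"}, "queries": []}  # placeholder payload

resp = session.post(f"{BASE_URL}/api/v1/chart/data", json=query_context)
if resp.status_code == 202:
    # Async path: body follows ChartDataAsyncResponseSchema (channel_id, job_id,
    # user_id, status, result_url). Poll /api/v1/async_event/ (see the earlier
    # sketch) until an event with this job_id reports status "done", then fetch
    # its result_url, which resolves to the data_from_cache endpoint:
    #   cached = session.get(BASE_URL + event["result_url"])
    #   data = cached.json()["result"]
    job = resp.json()
elif resp.status_code == 200:
    # Sync path: the feature flag is off, or the result format is non-JSON,
    # which bypasses the async flow.
    data = resp.json()["result"]
```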

View File

@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from flask import Request
from marshmallow import ValidationError
from superset import cache
from superset.charts.commands.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.base import BaseCommand
from superset.common.query_context import QueryContext
from superset.exceptions import CacheLoadError
from superset.extensions import async_query_manager
from superset.tasks.async_queries import load_chart_data_into_cache
logger = logging.getLogger(__name__)
class ChartDataCommand(BaseCommand):
def __init__(self) -> None:
self._form_data: Dict[str, Any]
self._query_context: QueryContext
self._async_channel_id: str
def run(self, **kwargs: Any) -> Dict[str, Any]:
# caching is handled in query_context.get_df_payload
# (also evals `force` property)
cache_query_context = kwargs.get("cache", False)
force_cached = kwargs.get("force_cached", False)
try:
payload = self._query_context.get_payload(
cache_query_context=cache_query_context, force_cached=force_cached
)
except CacheLoadError as exc:
raise ChartDataCacheLoadError(exc.message)
# TODO: QueryContext should support SIP-40 style errors
for query in payload["queries"]:
if query.get("error"):
raise ChartDataQueryFailedError(f"Error: {query['error']}")
return_value = {
"query_context": self._query_context,
"queries": payload["queries"],
}
if cache_query_context:
return_value.update(cache_key=payload["cache_key"])
return return_value
def run_async(self) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id)
load_chart_data_into_cache.delay(job_metadata, self._form_data)
return job_metadata
def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
self._form_data = form_data
try:
self._query_context = ChartDataQueryContextSchema().load(self._form_data)
except KeyError:
raise ValidationError("Request is incorrect")
except ValidationError as error:
raise error
return self._query_context
def validate(self) -> None:
self._query_context.raise_for_access()
def validate_async_request(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
def load_query_context_from_cache( # pylint: disable=no-self-use
self, cache_key: str
) -> Dict[str, Any]:
cache_value = cache.get(cache_key)
if not cache_value:
raise ChartDataCacheLoadError("Cached data not found")
return cache_value["data"]

View File

@ -90,6 +90,14 @@ class ChartBulkDeleteFailedError(DeleteFailedError):
message = _("Charts could not be deleted.")
class ChartDataQueryFailedError(CommandException):
pass
class ChartDataCacheLoadError(CommandException):
pass
class ChartBulkDeleteFailedReportsExistError(ChartBulkDeleteFailedError):
message = _("There are associated alerts or reports")

View File

@ -1105,6 +1105,18 @@ class ChartDataResponseSchema(Schema):
)
class ChartDataAsyncResponseSchema(Schema):
channel_id = fields.String(
description="Unique session async channel ID", allow_none=False,
)
job_id = fields.String(description="Unique async job ID", allow_none=False,)
user_id = fields.String(description="Requesting user ID", allow_none=True,)
status = fields.String(description="Status value for async job", allow_none=False,)
result_url = fields.String(
description="Unique result URL for fetching async query data", allow_none=False,
)
class ChartFavStarResponseResult(Schema):
id = fields.Integer(description="The Chart id")
value = fields.Boolean(description="The FaveStar value")
@ -1130,6 +1142,7 @@ class ImportV1ChartSchema(Schema):
CHART_SCHEMAS = (
ChartDataQueryContextSchema,
ChartDataResponseSchema,
ChartDataAsyncResponseSchema,
# TODO: These should optimally be included in the QueryContext schema as an `anyOf`
# in ChartDataPostPricessingOperation.options, but since `anyOf` is not supported
# by Marshmallow<3, this is not currently possible.

View File

@ -17,7 +17,7 @@
import copy
import logging
import math
from datetime import datetime, timedelta
from datetime import timedelta
from typing import Any, cast, ClassVar, Dict, List, Optional, Union
import numpy as np
@ -30,13 +30,17 @@ from superset.charts.dao import ChartDAO
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.exceptions import (
CacheLoadError,
QueryObjectValidationError,
SupersetException,
)
from superset.extensions import cache_manager, security_manager
from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import DTTM_ALIAS
from superset.views.utils import get_viz
from superset.viz import set_and_log_cache
config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
@ -78,6 +82,13 @@ class QueryContext:
self.custom_cache_timeout = custom_cache_timeout
self.result_type = result_type or utils.ChartDataResultType.FULL
self.result_format = result_format or utils.ChartDataResultFormat.JSON
self.cache_values = {
"datasource": datasource,
"queries": queries,
"force": force,
"result_type": result_type,
"result_format": result_format,
}
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
@ -142,8 +153,11 @@ class QueryContext:
return df.to_dict(orient="records")
def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
def get_single_payload(
self, query_obj: QueryObject, **kwargs: Any
) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
force_cached = kwargs.get("force_cached", False)
if self.result_type == utils.ChartDataResultType.QUERY:
return {
"query": self.datasource.get_query_str(query_obj.to_dict()),
@ -159,8 +173,7 @@ class QueryContext:
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in self.datasource.columns]
payload = self.get_df_payload(query_obj)
payload = self.get_df_payload(query_obj, force_cached=force_cached)
df = payload["df"]
status = payload["status"]
if status != utils.QueryStatus.FAILED:
@ -186,9 +199,28 @@ class QueryContext:
return {"data": payload["data"]}
return payload
def get_payload(self) -> List[Dict[str, Any]]:
"""Get all the payloads from the QueryObjects"""
return [self.get_single_payload(query_object) for query_object in self.queries]
def get_payload(self, **kwargs: Any) -> Dict[str, Any]:
cache_query_context = kwargs.get("cache_query_context", False)
force_cached = kwargs.get("force_cached", False)
# Get all the payloads from the QueryObjects
query_results = [
self.get_single_payload(query_object, force_cached=force_cached)
for query_object in self.queries
]
return_value = {"queries": query_results}
if cache_query_context:
cache_key = self.cache_key()
set_and_log_cache(
cache_manager.cache,
cache_key,
{"data": self.cache_values},
self.cache_timeout,
)
return_value["cache_key"] = cache_key # type: ignore
return return_value
@property
def cache_timeout(self) -> int:
@ -203,7 +235,22 @@ class QueryContext:
return self.datasource.database.cache_timeout
return config["CACHE_DEFAULT_TIMEOUT"]
def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
def cache_key(self, **extra: Any) -> str:
"""
The QueryContext cache key is made out of the key/values from
self.cache_values, plus any other key/values in `extra`. It includes only data
required to rehydrate a QueryContext object.
"""
key_prefix = "qc-"
cache_dict = self.cache_values.copy()
cache_dict.update(extra)
return generate_cache_key(cache_dict, key_prefix)
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
"""
Returns a QueryObject cache key for objects in self.queries
"""
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
cache_key = (
@ -215,7 +262,7 @@ class QueryContext:
and self.datasource.is_rls_supported
else [],
changed_on=self.datasource.changed_on,
**kwargs
**kwargs,
)
if query_obj
else None
@ -298,12 +345,12 @@ class QueryContext:
self, query_obj: QueryObject, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
cache_key = self.cache_key(query_obj, **kwargs)
force_cached = kwargs.get("force_cached", False)
cache_key = self.query_cache_key(query_obj)
logger.info("Cache key: %s", cache_key)
is_loaded = False
stacktrace = None
df = pd.DataFrame()
cached_dttm = datetime.utcnow().isoformat().split(".")[0]
cache_value = None
status = None
query = ""
@ -327,6 +374,12 @@ class QueryContext:
)
logger.info("Serving from cache")
if force_cached and not is_loaded:
logger.warning(
"force_cached (QueryContext): value not found for key %s", cache_key
)
raise CacheLoadError("Error loading data from cache")
if query_obj and not is_loaded:
try:
invalid_columns = [
@ -367,13 +420,11 @@ class QueryContext:
if is_loaded and cache_key and status != utils.QueryStatus.FAILED:
set_and_log_cache(
cache_key=cache_key,
df=df,
query=query,
annotation_data=annotation_data,
cached_dttm=cached_dttm,
cache_timeout=self.cache_timeout,
datasource_uid=self.datasource.uid,
cache_manager.data_cache,
cache_key,
{"df": df, "query": query, "annotation_data": annotation_data},
self.cache_timeout,
self.datasource.uid,
)
return {
"cache_key": cache_key,

View File

@ -277,7 +277,8 @@ class QueryObject:
:param df: DataFrame returned from database model.
:return: new DataFrame to which all post processing operations have been
applied
:raises ChartDataValidationError: If the post processing operation in incorrect
:raises QueryObjectValidationError: If the post processing operation
is incorrect
"""
for post_process in self.post_processing:
operation = post_process.get("operation")

View File

@ -327,6 +327,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
"DISPLAY_MARKDOWN_HTML": True,
# When True, this escapes HTML (rather than rendering it) in Markdown components
"ESCAPE_MARKDOWN_HTML": False,
"GLOBAL_ASYNC_QUERIES": False,
"VERSIONED_EXPORT": False,
# Note that: RowLevelSecurityFilter is only given by default to the Admin role
# and the Admin Role does have the all_datasources security permission.
@ -406,6 +407,9 @@ CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
# Cache for datasource metadata and query results
DATA_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
# store cache keys by datasource UID (via CacheKey) for custom processing/invalidation
STORE_CACHE_KEYS_IN_METADATA_DB = False
# CORS Options
ENABLE_CORS = False
CORS_OPTIONS: Dict[Any, Any] = {}
@ -965,6 +969,23 @@ SIP_15_TOAST_MESSAGE = (
# conventions and such. You can find examples in the tests.
SQLA_TABLE_MUTATOR = lambda table: table
# Global async query config options.
# Requires GLOBAL_ASYNC_QUERIES feature flag to be enabled.
GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
"port": 6379,
"host": "127.0.0.1",
"password": "",
"db": 0,
}
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-"
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = 500
if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not necessarily in pythonpath; useful
# for case where app is being executed via pex.
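
Putting the new configuration surface together, a deployment enabling the feature would override these values in `superset_config.py`. A minimal sketch (values illustrative; assumes the standard `FEATURE_FLAGS` override mechanism, and the JWT secret must be replaced):

```python
# superset_config.py -- illustrative values only.
FEATURE_FLAGS = {
    "GLOBAL_ASYNC_QUERIES": True,
}

# Redis instance backing the async event streams (example values).
GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
    "host": "redis.example.internal",
    "port": 6379,
    "password": "",
    "db": 0,
}

# Never ship the default secret; it signs the JWT set as the async cookie.
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "a-long-random-secret"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = True  # if served over HTTPS

# Polling is the only transport implemented in this PR (websockets are a TODO).
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = 500  # milliseconds

# Optionally restore the pre-#11499 behavior of recording CacheKey rows.
STORE_CACHE_KEYS_IN_METADATA_DB = True
```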

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from flask_babel import gettext as _
@ -65,6 +65,14 @@ class SupersetSecurityException(SupersetException):
self.payload = payload
class SupersetVizException(SupersetException):
status = 400
def __init__(self, errors: List[SupersetError]) -> None:
super(SupersetVizException, self).__init__(str(errors))
self.errors = errors
class NoDataException(SupersetException):
status = 400
@ -93,6 +101,10 @@ class QueryObjectValidationError(SupersetException):
status = 400
class CacheLoadError(SupersetException):
status = 404
class DashboardImportException(SupersetException):
pass

View File

@ -27,6 +27,7 @@ from flask_talisman import Talisman
from flask_wtf.csrf import CSRFProtect
from werkzeug.local import LocalProxy
from superset.utils.async_query_manager import AsyncQueryManager
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager
from superset.utils.machine_auth import MachineAuthProviderFactory
@ -97,6 +98,7 @@ class UIManifestProcessor:
APP_DIR = os.path.dirname(__file__)
appbuilder = AppBuilder(update_perms=False)
async_query_manager = AsyncQueryManager()
cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()

View File

@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, cast, Dict, Optional
from flask import current_app
from superset import app
from superset.exceptions import SupersetVizException
from superset.extensions import async_query_manager, cache_manager, celery_app
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.views.utils import get_datasource_info, get_viz
logger = logging.getLogger(__name__)
query_timeout = current_app.config[
"SQLLAB_ASYNC_TIME_LIMIT_SEC"
] # TODO: new config key
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
) -> None:
from superset.charts.commands.data import (
ChartDataCommand,
) # load here due to circular imports
with app.app_context(): # type: ignore
try:
command = ChartDataCommand()
command.set_query_context(form_data)
result = command.run(cache=True)
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
)
except Exception as exc:
# TODO: QueryContext should support SIP-40 style errors
error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member
errors = [{"message": error}]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise exc
return None
@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
def load_explore_json_into_cache(
job_metadata: Dict[str, Any],
form_data: Dict[str, Any],
response_type: Optional[str] = None,
force: bool = False,
) -> None:
with app.app_context(): # type: ignore
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=force,
)
# run query & cache results
payload = viz_obj.get_payload()
if viz_obj.has_error(payload):
raise SupersetVizException(errors=payload["errors"])
# cache form_data for async retrieval
cache_value = {"form_data": form_data, "response_type": response_type}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
)
except Exception as exc:
if isinstance(exc, SupersetVizException):
errors = exc.errors # pylint: disable=no-member
else:
error = (
exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member
)
errors = [error]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise exc
return None
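
For orientation, a condensed sketch of how these tasks are dispatched from the web tier, mirroring the `explore_json` view change later in this diff; the wrapper function name is invented for illustration.

```python
# Sketch only: the real dispatch lives in views/core.py (see below in this diff)
from superset.extensions import async_query_manager
from superset.tasks.async_queries import load_explore_json_into_cache

def dispatch_async_explore_json(request, form_data, response_type, force):
    # The JWT cookie set by AsyncQueryManager carries the client's channel id;
    # parsing it raises AsyncQueryTokenException if missing or invalid.
    channel_id = async_query_manager.parse_jwt_from_request(request)["channel"]
    job_metadata = async_query_manager.init_job(channel_id)
    # The Celery worker runs the query, caches the payload and publishes a
    # "done" or "error" event with a result_url to the channel's Redis stream.
    load_explore_json_into_cache.delay(job_metadata, form_data, response_type, force)
    return job_metadata  # returned to the client with HTTP 202
```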

View File

@ -0,0 +1,199 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple
import jwt
import redis
from flask import Flask, Request, Response, session
logger = logging.getLogger(__name__)
class AsyncQueryTokenException(Exception):
pass
class AsyncQueryJobException(Exception):
pass
def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]:
return {
"channel_id": channel_id,
"job_id": job_id,
"user_id": session.get("user_id"),
"status": kwargs.get("status"),
"errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"),
}
def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]:
event_id = event_data[0]
event_payload = event_data[1]["data"]
return {"id": event_id, **json.loads(event_payload)}
def increment_id(redis_id: str) -> str:
# redis stream IDs are in this format: '1607477697866-0';
# increment the sequence part after the last "-" so reads resume
# strictly after the last seen event
try:
prefix, last = redis_id.rsplit("-", 1)
return f"{prefix}-{int(last) + 1}"
except Exception: # pylint: disable=broad-except
return redis_id
class AsyncQueryManager:
MAX_EVENT_COUNT = 100
STATUS_PENDING = "pending"
STATUS_RUNNING = "running"
STATUS_ERROR = "error"
STATUS_DONE = "done"
def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
self._jwt_cookie_name: str
self._jwt_cookie_secure: bool = False
self._jwt_secret: str
def init_app(self, app: Flask) -> None:
config = app.config
if (
config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
):
raise Exception(
"""
Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
and non-null in order to enable async queries
"""
)
if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
raise AsyncQueryTokenException(
"Please provide a JWT secret at least 32 bytes long"
)
self._redis = redis.Redis( # type: ignore
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
self._stream_limit_firehose = config[
"GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE"
]
self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"]
self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"]
self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
@app.after_request
def validate_session( # pylint: disable=unused-variable
response: Response,
) -> Response:
reset_token = False
user_id = session["user_id"] if "user_id" in session else None
if "async_channel_id" not in session or "async_user_id" not in session:
reset_token = True
elif user_id != session["async_user_id"]:
reset_token = True
if reset_token:
async_channel_id = str(uuid.uuid4())
session["async_channel_id"] = async_channel_id
session["async_user_id"] = user_id
sub = str(user_id) if user_id else None
token = self.generate_jwt({"channel": async_channel_id, "sub": sub})
response.set_cookie(
self._jwt_cookie_name,
value=token,
httponly=True,
secure=self._jwt_cookie_secure,
# max_age=max_age or config.cookie_max_age,
# domain=config.cookie_domain,
# path=config.access_cookie_path,
# samesite=config.cookie_samesite
)
return response
def generate_jwt(self, data: Dict[str, Any]) -> str:
encoded_jwt = jwt.encode(data, self._jwt_secret, algorithm="HS256")
return encoded_jwt.decode("utf-8")
def parse_jwt(self, token: str) -> Dict[str, Any]:
data = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
return data
def parse_jwt_from_request(self, request: Request) -> Dict[str, Any]:
token = request.cookies.get(self._jwt_cookie_name)
if not token:
raise AsyncQueryTokenException("Token not present")
try:
return self.parse_jwt(token)
except Exception as exc:
logger.warning(exc)
raise AsyncQueryTokenException("Failed to parse token")
def init_job(self, channel_id: str) -> Dict[str, Any]:
job_id = str(uuid.uuid4())
return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING)
def read_events(
self, channel: str, last_id: Optional[str]
) -> List[Optional[Dict[str, Any]]]:
stream_name = f"{self._stream_prefix}{channel}"
start_id = increment_id(last_id) if last_id else "-"
results = self._redis.xrange( # type: ignore
stream_name, start_id, "+", self.MAX_EVENT_COUNT
)
return [] if not results else list(map(parse_event, results))
def update_job(
self, job_metadata: Dict[str, Any], status: str, **kwargs: Any
) -> None:
if "channel_id" not in job_metadata:
raise AsyncQueryJobException("No channel ID specified")
if "job_id" not in job_metadata:
raise AsyncQueryJobException("No job ID specified")
updates = {"status": status, **kwargs}
event_data = {"data": json.dumps({**job_metadata, **updates})}
full_stream_name = f"{self._stream_prefix}full"
scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}"
logger.debug("********** logging event data to stream %s", scoped_stream_name)
logger.debug(event_data)
self._redis.xadd( # type: ignore
scoped_stream_name, event_data, "*", self._stream_limit
)
self._redis.xadd( # type: ignore
full_stream_name, event_data, "*", self._stream_limit_firehose
)
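
As a standalone sketch of the token roundtrip performed by `generate_jwt` and `parse_jwt` above (PyJWT with HS256); the secret and claim values are illustrative.

```python
import jwt

secret = "test-secret-change-me-test-secret-change-me"  # must be >= 32 bytes
claims = {"channel": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "sub": "1"}

token = jwt.encode(claims, secret, algorithm="HS256")
if isinstance(token, bytes):  # PyJWT 1.x returns bytes, 2.x returns str
    token = token.decode("utf-8")

decoded = jwt.decode(token, secret, algorithms=["HS256"])
assert decoded["channel"] == claims["channel"]
```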

View File

@ -14,16 +14,65 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import hashlib
import json
import logging
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
from flask import current_app as app, request
from flask_caching import Cache
from werkzeug.wrappers.etag import ETagResponseMixin
from superset import db
from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.stats_logger import BaseStatsLogger
from superset.utils.core import json_int_dttm_ser
config = app.config # type: ignore
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
# TODO: DRY up cache key code
def json_dumps(obj: Any, sort_keys: bool = False) -> str:
return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys)
def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str:
json_data = json_dumps(values_dict, sort_keys=True)
hash_str = hashlib.md5(json_data.encode("utf-8")).hexdigest()
return f"{key_prefix}{hash_str}"
def set_and_log_cache(
cache_instance: Cache,
cache_key: str,
cache_value: Dict[str, Any],
cache_timeout: Optional[int] = None,
datasource_uid: Optional[str] = None,
) -> None:
timeout = cache_timeout if cache_timeout else config["CACHE_DEFAULT_TIMEOUT"]
try:
dttm = datetime.utcnow().isoformat().split(".")[0]
value = {**cache_value, "dttm": dttm}
cache_instance.set(cache_key, value, timeout=timeout)
stats_logger.incr("set_cache_key")
if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
ck = CacheKey(
cache_key=cache_key,
cache_timeout=cache_timeout,
datasource_uid=datasource_uid,
)
db.session.add(ck)
except Exception as ex: # pylint: disable=broad-except
# cache.set call can fail if the backend is down or if
# the key is too large or whatever other reasons
logger.warning("Could not cache key %s", cache_key)
logger.exception(ex)
# If a user sets `max_age` to 0, for how long should the browser cache the
# resource? Flask-Caching will cache forever, but for the HTTP header we need
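
A minimal standalone illustration of the key scheme used by `generate_cache_key` above: sorted-keys JSON hashed with MD5, so logically identical payloads map to the same key. This re-implementation substitutes `default=str` for `json_int_dttm_ser` to stay self-contained.

```python
import hashlib
import json

def demo_cache_key(values_dict, key_prefix=""):
    # mirrors generate_cache_key: canonical JSON, then an MD5 digest
    json_data = json.dumps(values_dict, sort_keys=True, default=str)
    return f"{key_prefix}{hashlib.md5(json_data.encode('utf-8')).hexdigest()}"

a = demo_cache_key({"form_data": {"viz_type": "dist_bar"}, "response_type": "full"}, "ejr-")
b = demo_cache_key({"response_type": "full", "form_data": {"viz_type": "dist_bar"}}, "ejr-")
assert a == b and a.startswith("ejr-")
```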

View File

@ -238,7 +238,7 @@ def pivot( # pylint: disable=too-many-arguments
Default to 'All'.
:param flatten_columns: Convert column names to strings
:return: A pivot table
:raises ChartDataValidationError: If the request in incorrect
:raises QueryObjectValidationError: If the request is incorrect
"""
if not index:
raise QueryObjectValidationError(
@ -293,7 +293,7 @@ def aggregate(
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
:raises ChartDataValidationError: If the request in incorrect
:raises QueryObjectValidationError: If the request is incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
@ -313,7 +313,7 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
:param columns: columns by which to sort. The key specifies the column name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
:raises ChartDataValidationError: If the request in incorrect
:raises QueryObjectValidationError: If the request is incorrect
"""
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
@ -348,7 +348,7 @@ def rolling( # pylint: disable=too-many-arguments
:param min_periods: The minimum amount of periods required for a row to be included
in the result set.
:return: DataFrame with the rolling columns
:raises ChartDataValidationError: If the request in incorrect
:raises QueryObjectValidationError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
df_rolling = df[columns.keys()]
@ -408,7 +408,7 @@ def select(
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
:raises ChartDataValidationError: If the request in incorrect
:raises QueryObjectValidationError: If the request is incorrect
"""
df_select = df.copy(deep=False)
if columns:
@ -433,7 +433,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
unchanged.
:param periods: periods to shift for calculating difference.
:return: DataFrame with diffed columns
:raises ChartDataValidationError: If the request in incorrect
:raises QueryObjectValidationError: If the request is incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods)

View File

@ -44,7 +44,8 @@ class Api(BaseSupersetView):
"""
query_context = QueryContext(**json.loads(request.form["query_context"]))
query_context.raise_for_access()
payload_json = query_context.get_payload()
result = query_context.get_payload()
payload_json = result["queries"]
return json.dumps(
payload_json, default=utils.json_int_dttm_ser, ignore_nan=True
)

View File

@ -77,6 +77,8 @@ FRONTEND_CONF_KEYS = (
"SUPERSET_WEBSERVER_DOMAINS",
"SQLLAB_SAVE_WARNING_MESSAGE",
"DISPLAY_MAX_ROW",
"GLOBAL_ASYNC_QUERIES_TRANSPORT",
"GLOBAL_ASYNC_QUERIES_POLLING_DELAY",
)
logger = logging.getLogger(__name__)

View File

@ -126,6 +126,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
method_permission_name = {
"bulk_delete": "delete",
"data": "list",
"data_from_cache": "list",
"delete": "delete",
"distinct": "list",
"export": "mulexport",

View File

@ -28,7 +28,11 @@ import simplejson as json
from flask import abort, flash, g, Markup, redirect, render_template, request, Response
from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access, has_access_api
from flask_appbuilder.security.decorators import (
has_access,
has_access_api,
permission_name,
)
from flask_appbuilder.security.sqla import models as ab_models
from flask_babel import gettext as __, lazy_gettext as _
from jinja2.exceptions import TemplateError
@ -66,6 +70,7 @@ from superset.dashboards.dao import DashboardDAO
from superset.databases.dao import DatabaseDAO
from superset.databases.filters import DatabaseFilter
from superset.exceptions import (
CacheLoadError,
CertificateException,
DatabaseNotFound,
SerializationError,
@ -73,6 +78,7 @@ from superset.exceptions import (
SupersetSecurityException,
SupersetTimeoutException,
)
from superset.extensions import async_query_manager, cache_manager
from superset.jinja_context import get_template_processor
from superset.models.core import Database, FavStar, Log
from superset.models.dashboard import Dashboard
@ -87,8 +93,10 @@ from superset.security.analytics_db_safety import (
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.tasks.async_queries import load_explore_json_into_cache
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.dates import now_as_float
from superset.views.base import (
@ -113,6 +121,7 @@ from superset.views.utils import (
apply_display_max_row_limit,
bootstrap_user_data,
check_datasource_perms,
check_explore_cache_perms,
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
@ -484,6 +493,43 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
payload = viz_obj.get_payload()
return data_payload_response(*viz_obj.payload_json_and_has_error(payload))
@event_logger.log_this
@api
@has_access_api
@handle_api_exception
@permission_name("explore_json")
@expose("/explore_json/data/<cache_key>", methods=["GET"])
@etag_cache(check_perms=check_explore_cache_perms)
def explore_json_data(self, cache_key: str) -> FlaskResponse:
"""Serves cached result data for async explore_json calls
`self.generate_json` receives this input and returns different
payloads based on the request args in the first block
TODO: form_data should not be loaded twice from cache
(also loaded in `check_explore_cache_perms`)
"""
try:
cached = cache_manager.cache.get(cache_key)
if not cached:
raise CacheLoadError("Cached data not found")
form_data = cached.get("form_data")
response_type = cached.get("response_type")
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force_cached=True,
)
return self.generate_json(viz_obj, response_type)
except SupersetException as ex:
return json_error_response(utils.error_msg_from_exception(ex), 400)
EXPLORE_JSON_METHODS = ["POST"]
if not is_feature_enabled("ENABLE_EXPLORE_JSON_CSRF_PROTECTION"):
EXPLORE_JSON_METHODS.append("GET")
@ -528,11 +574,31 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource_id, datasource_type, form_data
)
force = request.args.get("force") == "true"
# TODO: support CSV, SQL query and other non-JSON types
if (
is_feature_enabled("GLOBAL_ASYNC_QUERIES")
and response_type == utils.ChartDataResultFormat.JSON
):
try:
async_channel_id = async_query_manager.parse_jwt_from_request(
request
)["channel"]
job_metadata = async_query_manager.init_job(async_channel_id)
load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force
)
except AsyncQueryTokenException:
return json_error_response("Not authorized", 401)
return json_success(json.dumps(job_metadata), status=202)
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=request.args.get("force") == "true",
force=force,
)
return self.generate_json(viz_obj, response_type)
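
For reference, the shape of the async contract this view change introduces; the ids and URLs below are illustrative and adapted from the tests later in this diff.

```python
# 202 response to POST /superset/explore_json/ (or /api/v1/chart/data)
job_response = {
    "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
    "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512",
    "user_id": "1",
    "status": "pending",
    "errors": [],
    "result_url": None,
}

# A completed event returned by GET /api/v1/async_event/?last_id=<stream id>
event = {
    "id": "1607477697866-0",  # Redis stream id, echoed back as last_id
    "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
    "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512",
    "user_id": "1",
    "status": "done",
    "errors": [],
    # explore_json jobs point at /superset/explore_json/data/<cache_key>,
    # chart data jobs at /api/v1/chart/data/<cache_key>
    "result_url": "/superset/explore_json/data/ejr-0123456789abcdef",
}
```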

View File

@ -34,10 +34,12 @@ from superset import app, dataframe, db, is_feature_enabled, result_set
from superset.connectors.connector_registry import ConnectorRegistry
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
CacheLoadError,
SerializationError,
SupersetException,
SupersetSecurityException,
)
from superset.extensions import cache_manager
from superset.legacy import update_time_range
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@ -108,13 +110,19 @@ def get_permissions(
def get_viz(
form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
form_data: FormData,
datasource_type: str,
datasource_id: int,
force: bool = False,
force_cached: bool = False,
) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
viz_obj = viz.viz_types[viz_type](
datasource, form_data=form_data, force=force, force_cached=force_cached
)
return viz_obj
@ -422,10 +430,26 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
return obj and user in obj.owners
def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
"""
Loads async explore_json request data from cache and performs access check
:param _self: the Superset view instance
:param cache_key: the cache key passed into /explore_json/data/
:raises SupersetSecurityException: If the user cannot access the resource
"""
cached = cache_manager.cache.get(cache_key)
if not cached:
raise CacheLoadError("Cached data not found")
check_datasource_perms(_self, form_data=cached["form_data"])
def check_datasource_perms(
_self: Any,
datasource_type: Optional[str] = None,
datasource_id: Optional[int] = None,
**kwargs: Any
) -> None:
"""
Check if user can access a cached response from explore_json.
@ -438,7 +462,7 @@ def check_datasource_perms(
:raises SupersetSecurityException: If the user cannot access the resource
"""
form_data = get_form_data()[0]
form_data = kwargs["form_data"] if "form_data" in kwargs else get_form_data()[0]
try:
datasource_id, datasource_type = get_datasource_info(

View File

@ -38,6 +38,7 @@ from typing import (
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
@ -57,6 +58,7 @@ from superset import app, db, is_feature_enabled
from superset.constants import NULL_STRING
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
CacheLoadError,
NullValueException,
QueryObjectValidationError,
SpatialException,
@ -66,6 +68,7 @@ from superset.models.cache import CacheKey
from superset.models.helpers import QueryResult
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
DTTM_ALIAS,
JS_MAX_INTEGER,
@ -97,37 +100,6 @@ METRIC_KEYS = [
]
def set_and_log_cache(
cache_key: str,
df: pd.DataFrame,
query: str,
cached_dttm: str,
cache_timeout: int,
datasource_uid: Optional[str],
annotation_data: Optional[Dict[str, Any]] = None,
) -> None:
try:
cache_value = dict(
dttm=cached_dttm, df=df, query=query, annotation_data=annotation_data or {}
)
stats_logger.incr("set_cache_key")
cache_manager.data_cache.set(cache_key, cache_value, timeout=cache_timeout)
if datasource_uid:
ck = CacheKey(
cache_key=cache_key,
cache_timeout=cache_timeout,
datasource_uid=datasource_uid,
)
db.session.add(ck)
except Exception as ex:
# cache.set call can fail if the backend is down or if
# the key is too large or whatever other reasons
logger.warning("Could not cache key {}".format(cache_key))
logger.exception(ex)
cache_manager.data_cache.delete(cache_key)
class BaseViz:
"""All visualizations derive this base class"""
@ -144,6 +116,7 @@ class BaseViz:
datasource: "BaseDatasource",
form_data: Dict[str, Any],
force: bool = False,
force_cached: bool = False,
) -> None:
if not datasource:
raise QueryObjectValidationError(_("Viz is missing a datasource"))
@ -164,6 +137,7 @@ class BaseViz:
self.results: Optional[QueryResult] = None
self.errors: List[Dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
self.from_dttm: Optional[datetime] = None
self.to_dttm: Optional[datetime] = None
@ -180,6 +154,10 @@ class BaseViz:
self.applied_filters: List[Dict[str, str]] = []
self.rejected_filters: List[Dict[str, str]] = []
@property
def force_cached(self) -> bool:
return self._force_cached
def process_metrics(self) -> None:
# metrics in TableViz is order sensitive, so metric_dict should be
# OrderedDict
@ -272,7 +250,7 @@ class BaseViz:
"columns": [o.column_name for o in self.datasource.columns],
}
)
df = self.get_df(query_obj)
df = self.get_df_payload(query_obj)["df"] # leverage caching logic
return df.to_dict(orient="records")
def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
@ -518,7 +496,6 @@ class BaseViz:
is_loaded = False
stacktrace = None
df = None
cached_dttm = datetime.utcnow().isoformat().split(".")[0]
if cache_key and cache_manager.data_cache and not self.force:
cache_value = cache_manager.data_cache.get(cache_key)
if cache_value:
@ -539,6 +516,11 @@ class BaseViz:
logger.info("Serving from cache")
if query_obj and not is_loaded:
if self.force_cached:
logger.warning(
f"force_cached (viz.py): value not found for cache key {cache_key}"
)
raise CacheLoadError(_("Cached value not found"))
try:
invalid_columns = [
col
@ -590,12 +572,11 @@ class BaseViz:
if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED:
set_and_log_cache(
cache_key=cache_key,
df=df,
query=self.query,
cached_dttm=cached_dttm,
cache_timeout=self.cache_timeout,
datasource_uid=self.datasource.uid,
cache_manager.data_cache,
cache_key,
{"df": df, "query": self.query},
self.cache_timeout,
self.datasource.uid,
)
return {
"cache_key": self._any_cache_key,
@ -618,13 +599,15 @@ class BaseViz:
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
has_error = (
def has_error(self, payload: VizPayload) -> bool:
return (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
or bool(payload.get("errors"))
)
return self.json_dumps(payload), has_error
def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
return self.json_dumps(payload), self.has_error(payload)
@property
def data(self) -> Dict[str, Any]:
@ -638,7 +621,7 @@ class BaseViz:
return content
def get_csv(self) -> Optional[str]:
df = self.get_df()
df = self.get_df_payload()["df"] # leverage caching logic
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])
@ -1721,7 +1704,7 @@ class SunburstViz(BaseViz):
def get_data(self, df: pd.DataFrame) -> VizData:
if df.empty:
return None
fd = self.form_data
fd = copy.deepcopy(self.form_data)
cols = fd.get("groupby") or []
cols.extend(["m1", "m2"])
metric = utils.get_metric_name(fd["metric"])
@ -2983,12 +2966,14 @@ class PartitionViz(NVD3TimeSeriesViz):
return self.nest_values(levels)
def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]:
return set(cls.__subclasses__()).union(
[sc for c in cls.__subclasses__() for sc in get_subclasses(c)]
)
viz_types = {
o.viz_type: o
for o in globals().values()
if (
inspect.isclass(o)
and issubclass(o, BaseViz)
and o.viz_type not in config["VIZ_TYPE_DENYLIST"]
)
for o in get_subclasses(BaseViz)
if o.viz_type not in config["VIZ_TYPE_DENYLIST"]
}
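
A toy illustration of the recursive subclass discovery that now builds `viz_types`: unlike the previous `globals()` scan, `__subclasses__()` also picks up `BaseViz` subclasses defined in other modules. The class names below are invented.

```python
from typing import Set, Type

class Base:
    pass

class A(Base):
    pass

class B(A):  # found by recursing into A's subclasses
    pass

def get_subclasses(cls: Type[Base]) -> Set[Type[Base]]:
    return set(cls.__subclasses__()).union(
        sc for c in cls.__subclasses__() for sc in get_subclasses(c)
    )

assert get_subclasses(Base) == {A, B}
```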

View File

@ -57,13 +57,13 @@ from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
DTTM_ALIAS,
JS_MAX_INTEGER,
merge_extra_filters,
to_adhoc,
)
from superset.viz import set_and_log_cache
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@ -518,10 +518,9 @@ class BaseViz:
if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED:
set_and_log_cache(
cache_manager.data_cache,
cache_key,
df,
self.query,
cached_dttm,
{"df": df, "query": self.query},
self.cache_timeout,
self.datasource.uid,
)

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Optional
from unittest import mock
from superset.extensions import async_query_manager
from tests.base_tests import SupersetTestCase
from tests.test_app import app
class TestAsyncEventApi(SupersetTestCase):
UUID = "943c920-32a5-412a-977d-b8e47d36f5a4"
def fetch_events(self, last_id: Optional[str] = None):
base_uri = "api/v1/async_event/"
uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri
return self.client.get(uri)
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events(self, mock_uuid4):
async_query_manager.init_app(app)
self.login(username="admin")
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
rv = self.fetch_events()
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
mock_xrange.assert_called_with(channel_id, "-", "+", 100)
self.assertEqual(response, {"result": []})
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_last_id(self, mock_uuid4):
async_query_manager.init_app(app)
self.login(username="admin")
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
rv = self.fetch_events("1607471525180-0")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
self.assertEqual(response, {"result": []})
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_results(self, mock_uuid4):
async_query_manager.init_app(app)
self.login(username="admin")
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
mock_xrange.return_value = [
(
"1607477697866-0",
{
"data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463"}'
},
),
(
"1607477697993-0",
{
"data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d"}'
},
),
]
rv = self.fetch_events()
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
mock_xrange.assert_called_with(channel_id, "-", "+", 100)
expected = {
"result": [
{
"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
"errors": [],
"id": "1607477697866-0",
"job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512",
"result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463",
"status": "done",
"user_id": "1",
},
{
"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
"errors": [],
"id": "1607477697993-0",
"job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c",
"result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d",
"status": "done",
"user_id": "1",
},
]
}
self.assertEqual(response, expected)
def test_events_no_login(self):
async_query_manager.init_app(app)
rv = self.fetch_events()
assert rv.status_code == 401
def test_events_no_token(self):
self.login(username="admin")
self.client.set_cookie(
"localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], ""
)
rv = self.fetch_events()
assert rv.status_code == 401

View File

@ -35,6 +35,7 @@ class TestCache(SupersetTestCase):
cache_manager.data_cache.clear()
def test_no_data_cache(self):
data_cache_config = app.config["DATA_CACHE_CONFIG"]
app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}
cache_manager.init_app(app)
@ -48,11 +49,15 @@ class TestCache(SupersetTestCase):
resp_from_cache = self.get_json_resp(
json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
)
# restore DATA_CACHE_CONFIG
app.config["DATA_CACHE_CONFIG"] = data_cache_config
self.assertFalse(resp["is_cached"])
self.assertFalse(resp_from_cache["is_cached"])
def test_slice_data_cache(self):
# Override cache config
data_cache_config = app.config["DATA_CACHE_CONFIG"]
cache_default_timeout = app.config["CACHE_DEFAULT_TIMEOUT"]
app.config["CACHE_DEFAULT_TIMEOUT"] = 100
app.config["DATA_CACHE_CONFIG"] = {
"CACHE_TYPE": "simple",
@ -87,5 +92,6 @@ class TestCache(SupersetTestCase):
self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"]))
# reset cache config
app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}
app.config["DATA_CACHE_CONFIG"] = data_cache_config
app.config["CACHE_DEFAULT_TIMEOUT"] = cache_default_timeout
cache_manager.init_app(app)

View File

@ -31,21 +31,21 @@ from sqlalchemy import and_
from sqlalchemy.sql import func
from tests.test_app import app
from superset.connectors.sqla.models import SqlaTable
from superset.utils.core import AnnotationType, get_example_database
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
from superset.charts.commands.data import ChartDataCommand
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import async_query_manager, cache_manager, db, security_manager
from superset.models.annotations import AnnotationLayer
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.reports import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.utils import core as utils
from superset.utils.core import AnnotationType, get_example_database
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.base_tests import SupersetTestCase, post_assert_metric, test_client
from tests.fixtures.importexport import (
chart_config,
chart_metadata_config,
@ -53,6 +53,7 @@ from tests.fixtures.importexport import (
dataset_config,
dataset_metadata_config,
)
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
from tests.annotation_layers.fixtures import create_annotation_layers
@ -99,6 +100,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
db.session.commit()
return slice
@pytest.fixture(autouse=True)
def clear_data_cache(self):
with app.app_context():
cache_manager.data_cache.clear()
yield
@pytest.fixture()
def create_charts(self):
with self.create_app().app_context():
@ -1287,6 +1294,155 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
def test_chart_data_async(self):
"""
Chart data API: Test chart data query (async)
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 202)
data = json.loads(rv.data.decode("utf-8"))
keys = list(data.keys())
self.assertCountEqual(
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
def test_chart_data_async_results_type(self):
"""
Chart data API: Test chart data query non-JSON format (async)
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_type"] = "results"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
def test_chart_data_async_invalid_token(self):
"""
Chart data API: Test chart data query with an invalid token (async)
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
test_client.set_cookie(
"localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
)
rv = post_assert_metric(test_client, CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 401)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
@mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
def test_chart_data_cache(self, load_qc_mock):
"""
Chart data cache API: Test chart data async cache request
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
query_context = get_query_context(table.name, table.id, table.type)
load_qc_mock.return_value = query_context
orig_run = ChartDataCommand.run
def mock_run(self, **kwargs):
assert kwargs["force_cached"] == True
# override force_cached to get result from DB
return orig_run(self, force_cached=False)
with mock.patch.object(ChartDataCommand, "run", new=mock_run):
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["result"][0]["rowcount"], 45)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
@mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
def test_chart_data_cache_run_failed(self, load_qc_mock):
"""
Chart data cache API: Test chart data async cache request with run failure
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
query_context = get_query_context(table.name, table.id, table.type)
load_qc_mock.return_value = query_context
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(data["message"], "Error loading data from cache")
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
@mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
def test_chart_data_cache_no_login(self, load_qc_mock):
"""
Chart data cache API: Test chart data async cache request (no login)
"""
async_query_manager.init_app(app)
table = self.get_table_by_name("birth_names")
query_context = get_query_context(table.name, table.id, table.type)
load_qc_mock.return_value = query_context
orig_run = ChartDataCommand.run
def mock_run(self, **kwargs):
assert kwargs["force_cached"] == True
# override force_cached to get result from DB
return orig_run(self, force_cached=False)
with mock.patch.object(ChartDataCommand, "run", new=mock_run):
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
self.assertEqual(rv.status_code, 401)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
def test_chart_data_cache_key_error(self):
"""
Chart data cache API: Test chart data async cache request with invalid cache key
"""
async_query_manager.init_app(app)
self.login(username="admin")
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
self.assertEqual(rv.status_code, 404)
def test_export_chart(self):
"""
Chart API: Test export chart

View File

@ -52,6 +52,7 @@ from superset import (
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.extensions import async_query_manager
from superset.models import core as models
from superset.models.annotations import Annotation, AnnotationLayer
from superset.models.dashboard import Dashboard
@ -602,10 +603,13 @@ class TestCore(SupersetTestCase):
) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
def test_cache_logging(self):
store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
girls_slice = self.get_slice("Girls", db.session)
self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id))
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
assert ck.datasource_uid == f"{girls_slice.table.id}__table"
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys
def test_shortner(self):
self.login(username="admin")
@ -841,6 +845,191 @@ class TestCore(SupersetTestCase):
"The datasource associated with this chart no longer exists",
)
def test_explore_json(self):
tbl_id = self.table_ids.get("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"time_range": "No filter",
"metrics": ["count"],
"adhoc_filters": [],
"groupby": ["gender"],
"row_limit": 100,
}
self.login(username="admin")
rv = self.client.post(
"/superset/explore_json/", data={"form_data": json.dumps(form_data)},
)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["rowcount"], 2)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
def test_explore_json_async(self):
tbl_id = self.table_ids.get("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"time_range": "No filter",
"metrics": ["count"],
"adhoc_filters": [],
"groupby": ["gender"],
"row_limit": 100,
}
async_query_manager.init_app(app)
self.login(username="admin")
rv = self.client.post(
"/superset/explore_json/", data={"form_data": json.dumps(form_data)},
)
data = json.loads(rv.data.decode("utf-8"))
keys = list(data.keys())
self.assertEqual(rv.status_code, 202)
self.assertCountEqual(
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
)
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
GLOBAL_ASYNC_QUERIES=True,
)
def test_explore_json_async_results_format(self):
tbl_id = self.table_ids.get("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"time_range": "No filter",
"metrics": ["count"],
"adhoc_filters": [],
"groupby": ["gender"],
"row_limit": 100,
}
async_query_manager.init_app(app)
self.login(username="admin")
rv = self.client.post(
"/superset/explore_json/?results=true",
data={"form_data": json.dumps(form_data)},
)
self.assertEqual(rv.status_code, 200)
@mock.patch(
"superset.utils.cache_manager.CacheManager.cache",
new_callable=mock.PropertyMock,
)
@mock.patch("superset.viz.BaseViz.force_cached", new_callable=mock.PropertyMock)
def test_explore_json_data(self, mock_force_cached, mock_cache):
tbl_id = self.table_ids.get("birth_names")
form_data = dict(
{
"form_data": {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"time_range": "No filter",
"metrics": ["count"],
"adhoc_filters": [],
"groupby": ["gender"],
"row_limit": 100,
}
}
)
class MockCache:
def get(self, key):
return form_data
def set(self):
return None
mock_cache.return_value = MockCache()
mock_force_cached.return_value = False
self.login(username="admin")
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data["rowcount"], 2)
@mock.patch(
"superset.utils.cache_manager.CacheManager.cache",
new_callable=mock.PropertyMock,
)
def test_explore_json_data_no_login(self, mock_cache):
tbl_id = self.table_ids.get("birth_names")
form_data = dict(
{
"form_data": {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"time_range": "No filter",
"metrics": ["count"],
"adhoc_filters": [],
"groupby": ["gender"],
"row_limit": 100,
}
}
)
class MockCache:
def get(self, key):
return form_data
def set(self):
return None
mock_cache.return_value = MockCache()
rv = self.client.get("/superset/explore_json/data/valid-cache-key")
self.assertEqual(rv.status_code, 401)
def test_explore_json_data_invalid_cache_key(self):
self.login(username="admin")
cache_key = "invalid-cache-key"
rv = self.client.get(f"/superset/explore_json/data/{cache_key}")
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 404)
self.assertEqual(data["error"], "Cached data not found")
@mock.patch(
"superset.security.SupersetSecurityManager.get_schemas_accessible_by_user"
)

View File

@ -82,7 +82,7 @@ class TestQueryContext(SupersetTestCase):
# construct baseline cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
cache_key_original = query_context.query_cache_key(query_object)
# make temporary change and revert it to refresh the changed_on property
datasource = ConnectorRegistry.get_datasource(
@ -99,7 +99,7 @@ class TestQueryContext(SupersetTestCase):
# create new QueryContext with unchanged attributes and extract new cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_new = query_context.cache_key(query_object)
cache_key_new = query_context.query_cache_key(query_object)
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
@ -115,20 +115,20 @@ class TestQueryContext(SupersetTestCase):
# construct baseline cache_key from query_context with post processing operation
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
cache_key_original = query_context.query_cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_with_null = query_context.cache_key(query_object)
cache_key_with_null = query_context.query_cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_without_post_processing = query_context.cache_key(query_object)
cache_key_without_post_processing = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
def test_query_context_time_range_endpoints(self):
@ -179,13 +179,10 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
data = responses["queries"][0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
assert ck.datasource_uid == f"{table.id}__table"
def test_sql_injection_via_groupby(self):
"""
Ensure that calling invalid columns names in groupby are caught
@ -197,7 +194,7 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["groupby"] = ["currentDatabase()"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
assert query_payload["queries"][0].get("error") is not None
def test_sql_injection_via_columns(self):
"""
@ -212,7 +209,7 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["columns"] = ["*, 'extra'"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
assert query_payload["queries"][0].get("error") is not None
def test_sql_injection_via_metrics(self):
"""
@ -233,7 +230,7 @@ class TestQueryContext(SupersetTestCase):
]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
assert query_payload["queries"][0].get("error") is not None
def test_samples_response_type(self):
"""
@ -248,7 +245,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
data = responses["queries"][0]["data"]
self.assertIsInstance(data, list)
self.assertEqual(len(data), 5)
self.assertNotIn("sum__num", data[0])
@ -265,7 +262,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
response = responses[0]
response = responses["queries"][0]
self.assertEqual(len(response), 2)
self.assertEqual(response["language"], "sql")
self.assertIn("SELECT", response["query"])

View File

@ -22,7 +22,7 @@ from tests.superset_test_custom_template_processors import CustomPrestoTemplateP
AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
DEBUG = True
DEBUG = False
SUPERSET_WEBSERVER_PORT = 8081
# Allowing SQLALCHEMY_DATABASE_URI and SQLALCHEMY_EXAMPLES_URI to be defined as env vars for
@ -96,6 +96,8 @@ DATA_CACHE_CONFIG = {
"CACHE_KEY_PREFIX": "superset_data_cache",
}
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me-test-secret-change-me"
class CeleryConfig(object):
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"

tests/tasks/__init__.py (new file, 16 lines)
View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for async query celery jobs in Superset"""
import re
from unittest import mock
from uuid import uuid4
import pytest
from superset import db
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
)
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
from tests.test_app import app
def get_table_by_name(name: str) -> SqlaTable:
with app.app_context():
return db.session.query(SqlaTable).filter_by(table_name=name).one()
class TestAsyncQueries(SupersetTestCase):
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
form_data = get_query_context(table.name, table.id, table.type)
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"status": "pending",
"errors": [],
}
load_chart_data_into_cache(job_metadata, form_data)
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
@mock.patch.object(
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
)
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
form_data = get_query_context(table.name, table.id, table.type)
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"status": "pending",
"errors": [],
}
with pytest.raises(ChartDataQueryFailedError):
load_chart_data_into_cache(job_metadata, form_data)
mock_run_command.assert_called_with(cache=True)
errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{table.id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
"granularity_sqla": "ds",
"time_range": "No filter",
"metrics": ["count"],
"adhoc_filters": [],
"groupby": ["gender"],
"row_limit": 100,
}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"status": "pending",
"errors": [],
}
load_explore_json_into_cache(job_metadata, form_data)
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache_error(self, mock_update_job):
async_query_manager.init_app(app)
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": 1,
"status": "pending",
"errors": [],
}
with pytest.raises(SupersetException):
load_explore_json_into_cache(job_metadata, form_data)
errors = ["The datasource associated with this chart no longer exists"]
mock_update_job.assert_called_with(job_metadata, "error", errors=errors)

View File

@ -163,9 +163,20 @@ class TestBaseViz(SupersetTestCase):
datasource.database.cache_timeout = 1666
self.assertEqual(1666, test_viz.cache_timeout)
datasource.database.cache_timeout = None
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"],
test_viz.cache_timeout,
)
data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None
datasource.database.cache_timeout = None
test_viz = viz.BaseViz(datasource, form_data={})
self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout)
# restore DATA_CACHE_CONFIG timeout
app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout
class TestTableViz(SupersetTestCase):