mirror of https://github.com/apache/superset.git
feat(SIP-39): Async query support for charts (#11499)
* Generate JWT in Flask app
* Refactor chart data API query logic, add JWT validation and async worker
* Add redis stream implementation, refactoring
* Add chart data cache endpoint, refactor QueryContext caching
* Typing, linting, refactoring
* pytest fixes and openapi schema update
* Enforce caching be configured for async query init
* Async query processing for explore_json endpoint
* Add /api/v1/async_event endpoint
* Async frontend for dashboards [WIP]
* Chart async error message support, refactoring
* Abstract asyncEvent middleware
* Async chart loading for Explore
* Pylint fixes
* asyncEvent middleware -> TypeScript, JS linting
* Chart data API: enforce forced_cache, add tests
* Add tests for explore_json endpoints
* Add test for chart data cache endpoint (no login)
* Consolidate set_and_log_cache and add STORE_CACHE_KEYS_IN_METADATA_DB flag
* Add tests for tasks/async_queries and address PR comments
* Bypass non-JSON result formats for async queries
* Add tests for redux middleware
* Remove debug statement

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>

* Skip force_cached if no queryObj
* SunburstViz: don't modify self.form_data
* Fix failing annotation test
* Resolve merge/lint issues
* Reduce polling delay
* Fix new getClientErrorObject reference
* Fix flakey unit tests
* /api/v1/async_event: increment redis stream ID, add tests
* PR feedback: refactoring, configuration
* Fixup: remove debugging
* Fix typescript errors due to redux upgrade
* Update UPDATING.md
* Fix failing py tests
* asyncEvent_spec.js -> asyncEvent_spec.ts
* Refactor flakey Python 3.7 mock assertions
* Fix another shared state issue in Py tests
* Use 'sub' claim in JWT for user_id
* Refactor async middleware config
* Fixup: restore FeatureFlag boolean type

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Parent: 0fdf026cbc
Commit: 4d329071a1
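Several items in the commit message center on the Redis-streams event channel (e.g. "/api/v1/async_event: increment redis stream ID"). Since the `async_query_manager` implementation itself is not part of this diff, here is a hedged sketch of that read pattern; the stream key naming is an invented assumption:

```python
# Illustrative sketch of reading past a last-seen Redis stream ID, as the
# async_event endpoint further below does; key scheme is hypothetical.
from typing import Optional

import redis

r = redis.Redis()

def read_events(channel_id: str, last_id: Optional[str], limit: int = 100):
    stream_name = f"async-events-{channel_id}"  # assumed key naming
    # Stream IDs look like "<ms>-<seq>"; XRANGE bounds are inclusive, so
    # bump the sequence number to read strictly *after* last_id.
    if last_id:
        ms, seq = last_id.split("-")
        start = f"{ms}-{int(seq) + 1}"
    else:
        start = "-"  # from the beginning of the stream
    return r.xrange(stream_name, min=start, max="+", count=limit)
```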
@@ -623,6 +623,11 @@ cd superset-frontend
npm run test
```

To run a single test file:
```bash
npm run test -- path/to/file.js
```

### Integration Testing

We use [Cypress](https://www.cypress.io/) for integration tests. Tests can be run with `tox -e cypress`. To open Cypress and explore the tests, first set up and run the test server:
@@ -23,7 +23,7 @@ This file documents any backwards-incompatible changes in Superset and
assists people when migrating to a new version.

## Next

- [11499](https://github.com/apache/incubator-superset/pull/11499): Breaking change: `STORE_CACHE_KEYS_IN_METADATA_DB` config flag added (default=`False`) to write `CacheKey` records to the metadata DB. `CacheKey` recording was previously enabled by default.
- [11920](https://github.com/apache/incubator-superset/pull/11920): Undoes the DB migration from [11714](https://github.com/apache/incubator-superset/pull/11714) to prevent adding new columns to the logs table. Deploying a SHA between these two PRs may result in locking your DB.
- [11704](https://github.com/apache/incubator-superset/pull/11704): Breaking change: Jinja templating for SQL queries has been updated, removing default modules such as `datetime` and `random` and enforcing static template values. To restore or extend functionality, use `JINJA_CONTEXT_ADDONS` and `CUSTOM_TEMPLATE_PROCESSORS` in `superset_config.py`.
- [11714](https://github.com/apache/incubator-superset/pull/11714): Logs
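A minimal sketch of what a `superset_config.py` honoring the entries above might contain; the flag and setting names come from the entries themselves, while the Jinja helper keys are illustrative assumptions:

```python
# superset_config.py -- a hedged sketch, not a drop-in config.
import random
from datetime import datetime

# New flag from #11499; defaults to False, so CacheKey records are no
# longer written to the metadata DB unless you opt back in.
STORE_CACHE_KEYS_IN_METADATA_DB = True

# Restoring Jinja helpers removed by #11704. The keys below are
# illustrative; expose only what your templates actually need.
JINJA_CONTEXT_ADDONS = {
    "datetime": datetime,
    "random_value": lambda: random.random(),
}
```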
@@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,polyline,prison,pyarrow,pyhive,pytest,pytz,redis,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false
@@ -0,0 +1,265 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
import fetchMock from 'fetch-mock';
import sinon from 'sinon';
import * as featureFlags from 'src/featureFlags';
import initAsyncEvents from 'src/middleware/asyncEvent';

jest.useFakeTimers();

describe('asyncEvent middleware', () => {
  const next = sinon.spy();
  const state = {
    charts: {
      123: {
        id: 123,
        status: 'loading',
        asyncJobId: 'foo123',
      },
      345: {
        id: 345,
        status: 'loading',
        asyncJobId: 'foo345',
      },
    },
  };
  const events = [
    {
      status: 'done',
      result_url: '/api/v1/chart/data/cache-key-1',
      job_id: 'foo123',
      channel_id: '999',
      errors: [],
    },
    {
      status: 'done',
      result_url: '/api/v1/chart/data/cache-key-2',
      job_id: 'foo345',
      channel_id: '999',
      errors: [],
    },
  ];
  const mockStore = {
    getState: () => state,
    dispatch: sinon.stub(),
  };
  const action = {
    type: 'GENERIC_ACTION',
  };
  const EVENTS_ENDPOINT = 'glob:*/api/v1/async_event/*';
  const CACHED_DATA_ENDPOINT = 'glob:*/api/v1/chart/data/*';
  const config = {
    GLOBAL_ASYNC_QUERIES_TRANSPORT: 'polling',
    GLOBAL_ASYNC_QUERIES_POLLING_DELAY: 500,
  };
  let featureEnabledStub: any;

  function setup() {
    const getPendingComponents = sinon.stub();
    const successAction = sinon.spy();
    const errorAction = sinon.spy();
    const testCallback = sinon.stub();
    const testCallbackPromise = sinon.stub();
    testCallbackPromise.returns(
      new Promise(resolve => {
        testCallback.callsFake(resolve);
      }),
    );

    return {
      getPendingComponents,
      successAction,
      errorAction,
      testCallback,
      testCallbackPromise,
    };
  }

  beforeEach(() => {
    fetchMock.get(EVENTS_ENDPOINT, {
      status: 200,
      body: { result: [] },
    });
    fetchMock.get(CACHED_DATA_ENDPOINT, {
      status: 200,
      body: { result: { some: 'data' } },
    });
    featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled');
    featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(true);
  });
  afterEach(() => {
    fetchMock.reset();
    next.resetHistory();
    featureEnabledStub.restore();
  });
  afterAll(fetchMock.reset);

  it('should initialize and call next', () => {
    const { getPendingComponents, successAction, errorAction } = setup();
    getPendingComponents.returns([]);
    const asyncEventMiddleware = initAsyncEvents({
      config,
      getPendingComponents,
      successAction,
      errorAction,
    });
    asyncEventMiddleware(mockStore)(next)(action);
    expect(next.callCount).toBe(1);
  });

  it('should fetch events when there are pending components', () => {
    const {
      getPendingComponents,
      successAction,
      errorAction,
      testCallback,
      testCallbackPromise,
    } = setup();
    getPendingComponents.returns(Object.values(state.charts));
    const asyncEventMiddleware = initAsyncEvents({
      config,
      getPendingComponents,
      successAction,
      errorAction,
      processEventsCallback: testCallback,
    });

    asyncEventMiddleware(mockStore)(next)(action);

    return testCallbackPromise().then(() => {
      expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
    });
  });

  it('should fetch cached when there are successful events', () => {
    const {
      getPendingComponents,
      successAction,
      errorAction,
      testCallback,
      testCallbackPromise,
    } = setup();
    fetchMock.reset();
    fetchMock.get(EVENTS_ENDPOINT, {
      status: 200,
      body: { result: events },
    });
    fetchMock.get(CACHED_DATA_ENDPOINT, {
      status: 200,
      body: { result: { some: 'data' } },
    });
    getPendingComponents.returns(Object.values(state.charts));
    const asyncEventMiddleware = initAsyncEvents({
      config,
      getPendingComponents,
      successAction,
      errorAction,
      processEventsCallback: testCallback,
    });

    asyncEventMiddleware(mockStore)(next)(action);

    return testCallbackPromise().then(() => {
      expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
      expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2);
      expect(successAction.callCount).toBe(2);
    });
  });

  it('should call errorAction for cache fetch error responses', () => {
    const {
      getPendingComponents,
      successAction,
      errorAction,
      testCallback,
      testCallbackPromise,
    } = setup();
    fetchMock.reset();
    fetchMock.get(EVENTS_ENDPOINT, {
      status: 200,
      body: { result: events },
    });
    fetchMock.get(CACHED_DATA_ENDPOINT, {
      status: 400,
      body: { errors: ['error'] },
    });
    getPendingComponents.returns(Object.values(state.charts));
    const asyncEventMiddleware = initAsyncEvents({
      config,
      getPendingComponents,
      successAction,
      errorAction,
      processEventsCallback: testCallback,
    });

    asyncEventMiddleware(mockStore)(next)(action);

    return testCallbackPromise().then(() => {
      expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
      expect(fetchMock.calls(CACHED_DATA_ENDPOINT)).toHaveLength(2);
      expect(errorAction.callCount).toBe(2);
    });
  });

  it('should handle event fetching error responses', () => {
    const {
      getPendingComponents,
      successAction,
      errorAction,
      testCallback,
      testCallbackPromise,
    } = setup();
    fetchMock.reset();
    fetchMock.get(EVENTS_ENDPOINT, {
      status: 400,
      body: { message: 'error' },
    });
    getPendingComponents.returns(Object.values(state.charts));
    const asyncEventMiddleware = initAsyncEvents({
      config,
      getPendingComponents,
      successAction,
      errorAction,
      processEventsCallback: testCallback,
    });

    asyncEventMiddleware(mockStore)(next)(action);

    return testCallbackPromise().then(() => {
      expect(fetchMock.calls(EVENTS_ENDPOINT)).toHaveLength(1);
    });
  });

  it('should not fetch events when async queries are disabled', () => {
    featureEnabledStub.restore();
    featureEnabledStub = sinon.stub(featureFlags, 'isFeatureEnabled');
    featureEnabledStub.withArgs('GLOBAL_ASYNC_QUERIES').returns(false);
    const { getPendingComponents, successAction, errorAction } = setup();
    getPendingComponents.returns(Object.values(state.charts));
    const asyncEventMiddleware = initAsyncEvents({
      config,
      getPendingComponents,
      successAction,
      errorAction,
    });

    asyncEventMiddleware(mockStore)(next)(action);
    expect(getPendingComponents.called).toBe(false);
  });
});
@@ -17,7 +17,7 @@
 * under the License.
 */
import { ErrorTypeEnum } from 'src/components/ErrorMessage/types';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';

describe('getClientErrorObject()', () => {
  it('Returns a Promise', () => {
@@ -30,7 +30,7 @@ import {
  addSuccessToast as addSuccessToastAction,
  addWarningToast as addWarningToastAction,
} from '../../messageToasts/actions/index';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import COMMON_ERR_MESSAGES from '../../utils/errorMessages';

export const RESET_STATE = 'RESET_STATE';
@@ -25,7 +25,7 @@ import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
import Button from 'src/components/Button';
import CopyToClipboard from '../../components/CopyToClipboard';
import { storeQuery } from '../../utils/common';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import withToasts from '../../messageToasts/enhancers/withToasts';

const propTypes = {
@@ -38,7 +38,7 @@ import {
import { addDangerToast } from '../messageToasts/actions';
import { logEvent } from '../logger/actions';
import { Logger, LOG_ACTIONS_LOAD_CHART } from '../logger/LogUtils';
-import getClientErrorObject from '../utils/getClientErrorObject';
+import { getClientErrorObject } from '../utils/getClientErrorObject';
import { allowCrossDomain as domainShardingEnabled } from '../utils/hostNamesConfig';

export const CHART_UPDATE_STARTED = 'CHART_UPDATE_STARTED';
@@ -66,6 +66,11 @@ export function chartUpdateFailed(queryResponse, key) {
  return { type: CHART_UPDATE_FAILED, queryResponse, key };
}

export const CHART_UPDATE_QUEUED = 'CHART_UPDATE_QUEUED';
export function chartUpdateQueued(asyncJobMeta, key) {
  return { type: CHART_UPDATE_QUEUED, asyncJobMeta, key };
}

export const CHART_RENDERING_FAILED = 'CHART_RENDERING_FAILED';
export function chartRenderingFailed(error, key, stackTrace) {
  return { type: CHART_RENDERING_FAILED, error, key, stackTrace };
@@ -356,6 +361,12 @@ export function exploreJSON(

    const chartDataRequestCaught = chartDataRequest
      .then(response => {
        if (isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES)) {
          // deal with getChartDataRequest transforming the response data
          const result = 'result' in response ? response.result[0] : response;
          return dispatch(chartUpdateQueued(result, key));
        }

        // new API returns an object with an array of results
        // problem: response holds a list of results, when before we were just getting one result.
        // How to make the entire app compatible with multiple results?
@@ -71,6 +71,14 @@ export default function chartReducer(charts = {}, action) {
        chartUpdateEndTime: now(),
      };
    },
    [actions.CHART_UPDATE_QUEUED](state) {
      return {
        ...state,
        asyncJobId: action.asyncJobMeta.job_id,
        chartStatus: 'loading',
        chartUpdateEndTime: now(),
      };
    },
    [actions.CHART_RENDERING_SUCCEEDED](state) {
      return { ...state, chartStatus: 'rendered', chartUpdateEndTime: now() };
    },
@@ -21,7 +21,7 @@ import PropTypes from 'prop-types';
// TODO: refactor this with `import { AsyncSelect } from src/components/Select`
import { Select } from 'src/components/Select';
import { t, SupersetClient } from '@superset-ui/core';
-import getClientErrorObject from '../utils/getClientErrorObject';
+import { getClientErrorObject } from '../utils/getClientErrorObject';

const propTypes = {
  dataEndpoint: PropTypes.string.isRequired,
@@ -29,7 +29,7 @@ import {
  updateDirectPathToFilter,
} from './dashboardFilters';
import { applyDefaultFormData } from '../../explore/store';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import { SAVE_TYPE_OVERWRITE } from '../util/constants';
import {
  addSuccessToast,
@@ -17,7 +17,7 @@
 * under the License.
 */
import { SupersetClient } from '@superset-ui/core';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';

export const SET_DATASOURCE = 'SET_DATASOURCE';
export function setDatasource(datasource, key) {
@@ -22,7 +22,7 @@ import rison from 'rison';

import { addDangerToast } from 'src/messageToasts/actions';
import { getDatasourceParameter } from 'src/modules/utils';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';

export const SET_ALL_SLICES = 'SET_ALL_SLICES';
export function setAllSlices(slices) {
@@ -35,7 +35,7 @@ import FormLabel from 'src/components/FormLabel';
import { JsonEditor } from 'src/components/AsyncAceEditor';

import ColorSchemeControlWrapper from 'src/dashboard/components/ColorSchemeControlWrapper';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import withToasts from '../../messageToasts/enhancers/withToasts';
import '../stylesheets/buttons.less';
@@ -24,7 +24,9 @@ import { initFeatureFlags } from 'src/featureFlags';
import { initEnhancer } from '../reduxUtils';
import getInitialState from './reducers/getInitialState';
import rootReducer from './reducers/index';
import initAsyncEvents from '../middleware/asyncEvent';
import logger from '../middleware/loggerMiddleware';
import * as actions from '../chart/chartAction';

import App from './App';
@@ -33,10 +35,23 @@ const bootstrapData = JSON.parse(appContainer.getAttribute('data-bootstrap'));
initFeatureFlags(bootstrapData.common.feature_flags);
const initState = getInitialState(bootstrapData);

+const asyncEventMiddleware = initAsyncEvents({
+  config: bootstrapData.common.conf,
+  getPendingComponents: ({ charts }) =>
+    Object.values(charts).filter(c => c.chartStatus === 'loading'),
+  successAction: (componentId, componentData) =>
+    actions.chartUpdateSucceeded(componentData, componentId),
+  errorAction: (componentId, response) =>
+    actions.chartUpdateFailed(response, componentId),
+});
+
const store = createStore(
  rootReducer,
  initState,
-  compose(applyMiddleware(thunk, logger), initEnhancer(false)),
+  compose(
+    applyMiddleware(thunk, logger, asyncEventMiddleware),
+    initEnhancer(false),
+  ),
);

ReactDOM.render(<App store={store} />, document.getElementById('app'));
@@ -27,7 +27,7 @@ import { Alert, FormControl, FormControlProps } from 'react-bootstrap';
import { SupersetClient, t } from '@superset-ui/core';
import TableView from 'src/components/TableView';
import Modal from 'src/common/components/Modal';
-import getClientErrorObject from '../utils/getClientErrorObject';
+import { getClientErrorObject } from '../utils/getClientErrorObject';
import Loading from '../components/Loading';
import withToasts from '../messageToasts/enhancers/withToasts';
@@ -33,7 +33,7 @@ import Loading from 'src/components/Loading';
import TableSelector from 'src/components/TableSelector';
import EditableTitle from 'src/components/EditableTitle';

-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';

import CheckboxControl from 'src/explore/components/controls/CheckboxControl';
import TextControl from 'src/explore/components/controls/TextControl';
@@ -25,7 +25,7 @@ import Modal from 'src/common/components/Modal';
import AsyncEsmComponent from 'src/components/AsyncEsmComponent';
import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';

-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import withToasts from 'src/messageToasts/enhancers/withToasts';

const DatasourceEditor = AsyncEsmComponent(() => import('./DatasourceEditor'));
@@ -23,7 +23,7 @@ import Tabs from 'src/common/components/Tabs';
import Loading from 'src/components/Loading';
import TableView, { EmptyWrapperType } from 'src/components/TableView';
import { getChartDataRequest } from 'src/chart/chartAction';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import {
  CopyToClipboardButton,
  FilterInput,
@@ -30,7 +30,7 @@ import { DropdownButton } from 'react-bootstrap';
import { styled, t } from '@superset-ui/core';

import { Menu } from 'src/common/components';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';
import CopyToClipboard from '../../components/CopyToClipboard';
import { getChartDataRequest } from '../../chart/chartAction';
import downloadAsImage from '../../utils/downloadAsImage';
@@ -32,7 +32,7 @@ import rison from 'rison';
import { t, SupersetClient } from '@superset-ui/core';
import Chart, { Slice } from 'src/types/Chart';
import FormLabel from 'src/components/FormLabel';
-import getClientErrorObject from '../../utils/getClientErrorObject';
+import { getClientErrorObject } from '../../utils/getClientErrorObject';

type PropertiesModalProps = {
  slice: Slice;
@@ -25,6 +25,8 @@ import { initFeatureFlags } from '../featureFlags';
import { initEnhancer } from '../reduxUtils';
import getInitialState from './reducers/getInitialState';
import rootReducer from './reducers/index';
import initAsyncEvents from '../middleware/asyncEvent';
import * as actions from '../chart/chartAction';

import App from './App';
@@ -35,10 +37,23 @@ const bootstrapData = JSON.parse(
initFeatureFlags(bootstrapData.common.feature_flags);
const initState = getInitialState(bootstrapData);

+const asyncEventMiddleware = initAsyncEvents({
+  config: bootstrapData.common.conf,
+  getPendingComponents: ({ charts }) =>
+    Object.values(charts).filter(c => c.chartStatus === 'loading'),
+  successAction: (componentId, componentData) =>
+    actions.chartUpdateSucceeded(componentData, componentId),
+  errorAction: (componentId, response) =>
+    actions.chartUpdateFailed(response, componentId),
+});
+
const store = createStore(
  rootReducer,
  initState,
-  compose(applyMiddleware(thunk, logger), initEnhancer(false)),
+  compose(
+    applyMiddleware(thunk, logger, asyncEventMiddleware),
+    initEnhancer(false),
+  ),
);

ReactDOM.render(<App store={store} />, document.getElementById('app'));
@@ -34,6 +34,7 @@ export enum FeatureFlag {
  DISPLAY_MARKDOWN_HTML = 'DISPLAY_MARKDOWN_HTML',
  ESCAPE_MARKDOWN_HTML = 'ESCAPE_MARKDOWN_HTML',
  VERSIONED_EXPORT = 'VERSIONED_EXPORT',
  GLOBAL_ASYNC_QUERIES = 'GLOBAL_ASYNC_QUERIES',
}

export type FeatureFlagMap = {
@@ -0,0 +1,196 @@
/* ASF license header, as above */
import { Middleware, MiddlewareAPI, Dispatch } from 'redux';
import { makeApi, SupersetClient } from '@superset-ui/core';
import { SupersetError } from 'src/components/ErrorMessage/types';
import { isFeatureEnabled, FeatureFlag } from '../featureFlags';
import {
  getClientErrorObject,
  parseErrorJson,
} from '../utils/getClientErrorObject';

export type AsyncEvent = {
  id: string;
  channel_id: string;
  job_id: string;
  user_id: string;
  status: string;
  errors: SupersetError[];
  result_url: string;
};

type AsyncEventOptions = {
  config: {
    GLOBAL_ASYNC_QUERIES_TRANSPORT: string;
    GLOBAL_ASYNC_QUERIES_POLLING_DELAY: number;
  };
  getPendingComponents: (state: any) => any[];
  successAction: (componentId: number, componentData: any) => { type: string };
  errorAction: (componentId: number, response: any) => { type: string };
  processEventsCallback?: (events: AsyncEvent[]) => void; // this is currently used only for tests
};

type CachedDataResponse = {
  componentId: number;
  status: string;
  data: any;
};

const initAsyncEvents = (options: AsyncEventOptions) => {
  // TODO: implement websocket support
  const TRANSPORT_POLLING = 'polling';
  const {
    config,
    getPendingComponents,
    successAction,
    errorAction,
    processEventsCallback,
  } = options;
  const transport = config.GLOBAL_ASYNC_QUERIES_TRANSPORT || TRANSPORT_POLLING;
  const polling_delay = config.GLOBAL_ASYNC_QUERIES_POLLING_DELAY || 500;

  const middleware: Middleware = (store: MiddlewareAPI) => (next: Dispatch) => {
    const JOB_STATUS = {
      PENDING: 'pending',
      RUNNING: 'running',
      ERROR: 'error',
      DONE: 'done',
    };
    const LOCALSTORAGE_KEY = 'last_async_event_id';
    const POLLING_URL = '/api/v1/async_event/';
    let lastReceivedEventId: string | null;

    try {
      lastReceivedEventId = localStorage.getItem(LOCALSTORAGE_KEY);
    } catch (err) {
      console.warn('failed to fetch last event Id from localStorage');
    }

    const fetchEvents = makeApi<
      { last_id?: string | null },
      { result: AsyncEvent[] }
    >({
      method: 'GET',
      endpoint: POLLING_URL,
    });

    const fetchCachedData = async (
      asyncEvent: AsyncEvent,
      componentId: number,
    ): Promise<CachedDataResponse> => {
      let status = 'success';
      let data;
      try {
        const { json } = await SupersetClient.get({
          endpoint: asyncEvent.result_url,
        });
        data = 'result' in json ? json.result[0] : json;
      } catch (response) {
        status = 'error';
        data = await getClientErrorObject(response);
      }

      return { componentId, status, data };
    };

    const setLastId = (asyncEvent: AsyncEvent) => {
      lastReceivedEventId = asyncEvent.id;
      try {
        localStorage.setItem(LOCALSTORAGE_KEY, lastReceivedEventId as string);
      } catch (err) {
        console.warn('Error saving event ID to localStorage', err);
      }
    };

    const processEvents = async () => {
      const state = store.getState();
      const queuedComponents = getPendingComponents(state);
      const eventArgs = lastReceivedEventId
        ? { last_id: lastReceivedEventId }
        : {};
      const events: AsyncEvent[] = [];
      if (queuedComponents && queuedComponents.length) {
        try {
          const { result: events } = await fetchEvents(eventArgs);
          if (events && events.length) {
            const componentsByJobId = queuedComponents.reduce((acc, item) => {
              acc[item.asyncJobId] = item;
              return acc;
            }, {});
            const fetchDataEvents: Promise<CachedDataResponse>[] = [];
            events.forEach((asyncEvent: AsyncEvent) => {
              const component = componentsByJobId[asyncEvent.job_id];
              if (!component) {
                console.warn(
                  'component not found for job_id',
                  asyncEvent.job_id,
                );
                return setLastId(asyncEvent);
              }
              const componentId = component.id;
              switch (asyncEvent.status) {
                case JOB_STATUS.DONE:
                  fetchDataEvents.push(
                    fetchCachedData(asyncEvent, componentId),
                  );
                  break;
                case JOB_STATUS.ERROR:
                  store.dispatch(
                    errorAction(componentId, parseErrorJson(asyncEvent)),
                  );
                  break;
                default:
                  console.warn('received event with status', asyncEvent.status);
              }

              return setLastId(asyncEvent);
            });

            const fetchResults = await Promise.all(fetchDataEvents);
            fetchResults.forEach(result => {
              if (result.status === 'success') {
                store.dispatch(successAction(result.componentId, result.data));
              } else {
                store.dispatch(errorAction(result.componentId, result.data));
              }
            });
          }
        } catch (err) {
          console.warn(err);
        }
      }

      if (processEventsCallback) processEventsCallback(events);

      return setTimeout(processEvents, polling_delay);
    };

    if (
      isFeatureEnabled(FeatureFlag.GLOBAL_ASYNC_QUERIES) &&
      transport === TRANSPORT_POLLING
    )
      processEvents();

    return action => next(action);
  };

  return middleware;
};

export default initAsyncEvents;
@@ -19,7 +19,8 @@
/* eslint global-require: 0 */
import $ from 'jquery';
import { SupersetClient } from '@superset-ui/core';
-import getClientErrorObject, {
+import {
+  getClientErrorObject,
  ClientErrorObject,
} from '../utils/getClientErrorObject';
import setupErrorMessages from './setupErrorMessages';
@@ -21,7 +21,7 @@ import {
  getTimeFormatter,
  TimeFormats,
} from '@superset-ui/core';
-import getClientErrorObject from './getClientErrorObject';
+import { getClientErrorObject } from './getClientErrorObject';

// ATTENTION: If you change any constants, make sure to also change constants.py
@@ -16,7 +16,7 @@
 * specific language governing permissions and limitations
 * under the License.
 */
-import { SupersetClientResponse, t } from '@superset-ui/core';
+import { JsonObject, SupersetClientResponse, t } from '@superset-ui/core';
import {
  SupersetError,
  ErrorTypeEnum,
@@ -36,7 +36,33 @@ export type ClientErrorObject = {
  stacktrace?: string;
} & Partial<SupersetClientResponse>;

-export default function getClientErrorObject(
+export function parseErrorJson(responseObject: JsonObject): ClientErrorObject {
+  let error = { ...responseObject };
+  // Backwards compatibility for old error renderers with the new error object
+  if (error.errors && error.errors.length > 0) {
+    error.error = error.description = error.errors[0].message;
+    error.link = error.errors[0]?.extra?.link;
+  }
+
+  if (error.stack) {
+    error = {
+      ...error,
+      error:
+        t('Unexpected error: ') +
+        (error.description || t('(no description, click to see stack trace)')),
+      stacktrace: error.stack,
+    };
+  } else if (error.responseText && error.responseText.indexOf('CSRF') >= 0) {
+    error = {
+      ...error,
+      error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT),
+    };
+  }
+
+  return { ...error, error: error.error }; // explicit ClientErrorObject
+}
+
+export function getClientErrorObject(
  response: SupersetClientResponse | (Response & { timeout: number }) | string,
): Promise<ClientErrorObject> {
  // takes a SupersetClientResponse as input, attempts to read response as Json if possible,
@@ -54,33 +80,8 @@ export default function getClientErrorObject(
        .clone()
        .json()
        .then(errorJson => {
-          let error = { ...responseObject, ...errorJson };
-
-          // Backwards compatibility for old error renderers with the new error object
-          if (error.errors && error.errors.length > 0) {
-            error.error = error.description = error.errors[0].message;
-            error.link = error.errors[0]?.extra?.link;
-          }
-
-          if (error.stack) {
-            error = {
-              ...error,
-              error:
-                t('Unexpected error: ') +
-                (error.description ||
-                  t('(no description, click to see stack trace)')),
-              stacktrace: error.stack,
-            };
-          } else if (
-            error.responseText &&
-            error.responseText.indexOf('CSRF') >= 0
-          ) {
-            error = {
-              ...error,
-              error: t(COMMON_ERR_MESSAGES.SESSION_TIMED_OUT),
-            };
-          }
-          resolve(error);
+          const error = { ...responseObject, ...errorJson };
+          resolve(parseErrorJson(error));
        })
        .catch(() => {
          // fall back to reading as text
@@ -29,7 +29,7 @@ import ConfirmStatusChange from 'src/components/ConfirmStatusChange';
import DeleteModal from 'src/components/DeleteModal';
import ListView, { ListViewProps } from 'src/components/ListView';
import SubMenu, { SubMenuProps } from 'src/components/Menu/SubMenu';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import withToasts from 'src/messageToasts/enhancers/withToasts';
import { IconName } from 'src/components/Icon';
import { useListViewResource } from 'src/views/CRUD/hooks';
@@ -21,7 +21,7 @@ import { styled, t, SupersetClient } from '@superset-ui/core';
import InfoTooltip from 'src/common/components/InfoTooltip';
import { useSingleViewResource } from 'src/views/CRUD/hooks';
import withToasts from 'src/messageToasts/enhancers/withToasts';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import Icon from 'src/components/Icon';
import Modal from 'src/common/components/Modal';
import Tabs from 'src/common/components/Tabs';
@@ -25,7 +25,7 @@ import { FetchDataConfig } from 'src/components/ListView';
import { FilterValue } from 'src/components/ListView/types';
import Chart, { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import { FavoriteStatus } from './types';

interface ListViewResourceState<D extends object = any> {
@@ -25,7 +25,7 @@ import {
} from '@superset-ui/core';
import Chart from 'src/types/Chart';
import rison from 'rison';
-import getClientErrorObject from 'src/utils/getClientErrorObject';
+import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import { FetchDataConfig } from 'src/components/ListView';
import { Dashboard } from './types';
@@ -30,6 +30,7 @@ from superset.extensions import (
    _event_logger,
    APP_DIR,
    appbuilder,
    async_query_manager,
    cache_manager,
    celery_app,
    csrf,
@@ -127,6 +128,7 @@ class SupersetAppInitializer:
        # pylint: disable=too-many-branches
        from superset.annotation_layers.api import AnnotationLayerRestApi
        from superset.annotation_layers.annotations.api import AnnotationRestApi
        from superset.async_events.api import AsyncEventsRestApi
        from superset.cachekeys.api import CacheRestApi
        from superset.charts.api import ChartRestApi
        from superset.connectors.druid.views import (
@@ -201,6 +203,7 @@ class SupersetAppInitializer:
        #
        appbuilder.add_api(AnnotationRestApi)
        appbuilder.add_api(AnnotationLayerRestApi)
        appbuilder.add_api(AsyncEventsRestApi)
        appbuilder.add_api(CacheRestApi)
        appbuilder.add_api(ChartRestApi)
        appbuilder.add_api(CssTemplateRestApi)
@@ -498,6 +501,7 @@ class SupersetAppInitializer:
        self.configure_url_map_converters()
        self.configure_data_sources()
        self.configure_auth_provider()
        self.configure_async_queries()

        # Hook that provides administrators a handle on the Flask APP
        # after initialization
@@ -648,6 +652,10 @@ class SupersetAppInitializer:
        for ex in csrf_exempt_list:
            csrf.exempt(ex)

    def configure_async_queries(self) -> None:
        if feature_flag_manager.is_feature_enabled("GLOBAL_ASYNC_QUERIES"):
            async_query_manager.init_app(self.flask_app)

    def register_blueprints(self) -> None:
        for bp in self.config["BLUEPRINTS"]:
            try:
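For reference, a minimal sketch of opting into this code path via `superset_config.py`; `FEATURE_FLAGS` is Superset's standard feature-flag map, and note the commit also enforces that caching be configured before async query init (that setup is not shown here):

```python
# superset_config.py -- a hedged sketch; with this flag set,
# configure_async_queries() above initializes async_query_manager.
FEATURE_FLAGS = {
    "GLOBAL_ASYNC_QUERIES": True,
}
```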
@@ -0,0 +1,16 @@
# ASF license header, as above (new license-only module)
@@ -0,0 +1,99 @@
# ASF license header, as above
import logging

from flask import request, Response
from flask_appbuilder import expose
from flask_appbuilder.api import BaseApi, safe
from flask_appbuilder.security.decorators import permission_name, protect

from superset.extensions import async_query_manager, event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException

logger = logging.getLogger(__name__)


class AsyncEventsRestApi(BaseApi):
    resource_name = "async_event"
    allow_browser_login = True
    include_route_methods = {
        "events",
    }

    @expose("/", methods=["GET"])
    @event_logger.log_this
    @protect()
    @safe
    @permission_name("list")
    def events(self) -> Response:
        """
        Reads off of the Redis async events stream, using the user's JWT token and
        optional query params for last event received.
        ---
        get:
          description: >-
            Reads off of the Redis events stream, using the user's JWT token and
            optional query params for last event received.
          parameters:
          - in: query
            name: last_id
            description: Last ID received by the client
            schema:
              type: string
          responses:
            200:
              description: Async event results
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      result:
                        type: array
                        items:
                          type: object
                          properties:
                            id:
                              type: string
                            channel_id:
                              type: string
                            job_id:
                              type: string
                            user_id:
                              type: integer
                            status:
                              type: string
                            msg:
                              type: string
                            cache_key:
                              type: string
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            async_channel_id = async_query_manager.parse_jwt_from_request(request)[
                "channel"
            ]
            last_event_id = request.args.get("last_id")
            events = async_query_manager.read_events(async_channel_id, last_event_id)

        except AsyncQueryTokenException:
            return self.response_401()

        return self.response(200, result=events)
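A hedged sketch of how a client might consume this endpoint, mirroring the polling loop in the frontend middleware above; `session` is assumed to be an authenticated `requests.Session` that already carries the async-query JWT cookie:

```python
# Illustrative only -- not part of this change.
import time

import requests

def poll_async_events(session: requests.Session, base_url: str, delay_s: float = 0.5):
    """Yield async query events, tracking the last-received stream ID."""
    last_id = None
    while True:
        params = {"last_id": last_id} if last_id else {}
        resp = session.get(f"{base_url}/api/v1/async_event/", params=params)
        resp.raise_for_status()
        for event in resp.json()["result"]:
            last_id = event["id"]  # passed back so the server reads past it
            yield event
        time.sleep(delay_s)  # cf. GLOBAL_ASYNC_QUERIES_POLLING_DELAY
```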
@@ -33,10 +33,13 @@ from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset.charts.commands.bulk_delete import BulkDeleteChartCommand
from superset.charts.commands.create import CreateChartCommand
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.delete import DeleteChartCommand
from superset.charts.commands.exceptions import (
    ChartBulkDeleteFailedError,
    ChartCreateFailedError,
    ChartDataCacheLoadError,
    ChartDataQueryFailedError,
    ChartDeleteFailedError,
    ChartForbiddenError,
    ChartInvalidError,
@@ -50,7 +53,6 @@ from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
from superset.charts.schemas import (
    CHART_SCHEMAS,
-    ChartDataQueryContextSchema,
    ChartPostSchema,
    ChartPutSchema,
    get_delete_ids_schema,
@@ -67,7 +69,12 @@ from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
-from superset.utils.core import ChartDataResultFormat, json_int_dttm_ser
+from superset.utils.async_query_manager import AsyncQueryTokenException
+from superset.utils.core import (
+    ChartDataResultFormat,
+    ChartDataResultType,
+    json_int_dttm_ser,
+)
from superset.utils.screenshots import ChartScreenshot
from superset.utils.urls import get_url_path
from superset.views.base_api import (
@@ -93,6 +100,8 @@ class ChartRestApi(BaseSupersetModelRestApi):
        RouteMethod.RELATED,
        "bulk_delete",  # not using RouteMethod since locally defined
        "data",
        "data_from_cache",
        "viz_types",
        "favorite_status",
    }
    class_permission_name = "SliceModelView"
@@ -448,6 +457,39 @@ class ChartRestApi(BaseSupersetModelRestApi):
        except ChartBulkDeleteFailedError as ex:
            return self.response_422(message=str(ex))

    def get_data_response(
        self, command: ChartDataCommand, force_cached: bool = False
    ) -> Response:
        try:
            result = command.run(force_cached=force_cached)
        except ChartDataCacheLoadError as exc:
            return self.response_422(message=exc.message)
        except ChartDataQueryFailedError as exc:
            return self.response_400(message=exc.message)

        result_format = result["query_context"].result_format
        if result_format == ChartDataResultFormat.CSV:
            # return the first result
            data = result["queries"][0]["data"]
            return CsvResponse(
                data,
                status=200,
                headers=generate_download_headers("csv"),
                mimetype="application/csv",
            )

        if result_format == ChartDataResultFormat.JSON:
            response_data = simplejson.dumps(
                {"result": result["queries"]},
                default=json_int_dttm_ser,
                ignore_nan=True,
            )
            resp = make_response(response_data, 200)
            resp.headers["Content-Type"] = "application/json; charset=utf-8"
            return resp

        return self.response_400(message=f"Unsupported result_format: {result_format}")

    @expose("/data", methods=["POST"])
    @protect()
    @safe
@@ -478,8 +520,16 @@ class ChartRestApi(BaseSupersetModelRestApi):
              application/json:
                schema:
                  $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
@@ -490,47 +540,88 @@ class ChartRestApi(BaseSupersetModelRestApi):
            json_body = json.loads(request.form["form_data"])
        else:
            return self.response_400(message="Request is not JSON")

        try:
-            query_context = ChartDataQueryContextSchema().load(json_body)
-        except KeyError:
-            return self.response_400(message="Request is incorrect")
+            command = ChartDataCommand()
+            query_context = command.set_query_context(json_body)
+            command.validate()
        except ValidationError as error:
            return self.response_400(
                message=_("Request is incorrect: %(error)s", error=error.messages)
            )
-        try:
-            query_context.raise_for_access()
        except SupersetSecurityException:
            return self.response_401()
-        payload = query_context.get_payload()
-        for query in payload:
-            if query.get("error"):
-                return self.response_400(message=f"Error: {query['error']}")
-        result_format = query_context.result_format
-
-        response = self.response_400(
-            message=f"Unsupported result_format: {result_format}"
-        )
+        # TODO: support CSV, SQL query and other non-JSON types
+        if (
+            is_feature_enabled("GLOBAL_ASYNC_QUERIES")
+            and query_context.result_format == ChartDataResultFormat.JSON
+            and query_context.result_type == ChartDataResultType.FULL
+        ):

-        if result_format == ChartDataResultFormat.CSV:
-            # return the first result
-            result = payload[0]["data"]
-            response = CsvResponse(
-                result,
-                status=200,
-                headers=generate_download_headers("csv"),
-                mimetype="application/csv",
+            try:
+                command.validate_async_request(request)
+            except AsyncQueryTokenException:
+                return self.response_401()
+
+            result = command.run_async()
+            return self.response(202, **result)
+
+        return self.get_data_response(command)
+
+    @expose("/data/<cache_key>", methods=["GET"])
+    @event_logger.log_this
+    @protect()
+    @safe
+    @statsd_metrics
+    def data_from_cache(self, cache_key: str) -> Response:
+        """
+        Takes a query context cache key and returns payload
+        data response for the given query.
+        ---
+        get:
+          description: >-
+            Takes a query context cache key and returns payload data
+            response for the given query.
+          parameters:
+          - in: path
+            schema:
+              type: string
+            name: cache_key
+          responses:
+            200:
+              description: Query result
+              content:
+                application/json:
+                  schema:
+                    $ref: "#/components/schemas/ChartDataResponseSchema"
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            404:
+              $ref: '#/components/responses/404'
+            422:
+              $ref: '#/components/responses/422'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        command = ChartDataCommand()
+        try:
+            cached_data = command.load_query_context_from_cache(cache_key)
+            command.set_query_context(cached_data)
+            command.validate()
+        except ChartDataCacheLoadError:
+            return self.response_404()
+        except ValidationError as error:
+            return self.response_400(
+                message=_("Request is incorrect: %(error)s", error=error.messages)
+            )
+        except SupersetSecurityException as exc:
+            logger.info(exc)
+            return self.response_401()

-        if result_format == ChartDataResultFormat.JSON:
-            response_data = simplejson.dumps(
-                {"result": payload}, default=json_int_dttm_ser, ignore_nan=True
-            )
-            resp = make_response(response_data, 200)
-            resp.headers["Content-Type"] = "application/json; charset=utf-8"
-            response = resp
-
-        return response
+        return self.get_data_response(command, True)

    @expose("/<pk>/cache_screenshot/", methods=["GET"])
    @protect()
@@ -0,0 +1,101 @@
# ASF license header, as above
import logging
from typing import Any, Dict

from flask import Request
from marshmallow import ValidationError

from superset import cache
from superset.charts.commands.exceptions import (
    ChartDataCacheLoadError,
    ChartDataQueryFailedError,
)
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.base import BaseCommand
from superset.common.query_context import QueryContext
from superset.exceptions import CacheLoadError
from superset.extensions import async_query_manager
from superset.tasks.async_queries import load_chart_data_into_cache

logger = logging.getLogger(__name__)


class ChartDataCommand(BaseCommand):
    def __init__(self) -> None:
        self._form_data: Dict[str, Any]
        self._query_context: QueryContext
        self._async_channel_id: str

    def run(self, **kwargs: Any) -> Dict[str, Any]:
        # caching is handled in query_context.get_df_payload
        # (also evals `force` property)
        cache_query_context = kwargs.get("cache", False)
        force_cached = kwargs.get("force_cached", False)
        try:
            payload = self._query_context.get_payload(
                cache_query_context=cache_query_context, force_cached=force_cached
            )
        except CacheLoadError as exc:
            raise ChartDataCacheLoadError(exc.message)

        # TODO: QueryContext should support SIP-40 style errors
        for query in payload["queries"]:
            if query.get("error"):
                raise ChartDataQueryFailedError(f"Error: {query['error']}")

        return_value = {
            "query_context": self._query_context,
            "queries": payload["queries"],
        }
        if cache_query_context:
            return_value.update(cache_key=payload["cache_key"])

        return return_value

    def run_async(self) -> Dict[str, Any]:
        job_metadata = async_query_manager.init_job(self._async_channel_id)
        load_chart_data_into_cache.delay(job_metadata, self._form_data)

        return job_metadata

    def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
        self._form_data = form_data
        try:
            self._query_context = ChartDataQueryContextSchema().load(self._form_data)
        except KeyError:
            raise ValidationError("Request is incorrect")
        except ValidationError as error:
            raise error

        return self._query_context

    def validate(self) -> None:
        self._query_context.raise_for_access()

    def validate_async_request(self, request: Request) -> None:
        jwt_data = async_query_manager.parse_jwt_from_request(request)
        self._async_channel_id = jwt_data["channel"]

    def load_query_context_from_cache(  # pylint: disable=no-self-use
        self, cache_key: str
    ) -> Dict[str, Any]:
        cache_value = cache.get(cache_key)
        if not cache_value:
            raise ChartDataCacheLoadError("Cached data not found")

        return cache_value["data"]
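A brief sketch of how the API layer above drives this command (cf. `get_data_response` and `data_from_cache`); `form_data` is assumed to be a deserialized query-context payload such as the body of POST /api/v1/chart/data:

```python
# Illustrative call sequence under the assumption of a valid form_data dict.
command = ChartDataCommand()
command.set_query_context(form_data)  # marshmallow-validates the payload
command.validate()                    # security check via raise_for_access()
result = command.run(cache=True)      # also returns "cache_key" when cache=True
queries = result["queries"]           # one payload dict per query object
```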
@@ -90,6 +90,14 @@ class ChartBulkDeleteFailedError(DeleteFailedError):
    message = _("Charts could not be deleted.")


class ChartDataQueryFailedError(CommandException):
    pass


class ChartDataCacheLoadError(CommandException):
    pass


class ChartBulkDeleteFailedReportsExistError(ChartBulkDeleteFailedError):
    message = _("There are associated alerts or reports")
@@ -1105,6 +1105,18 @@ class ChartDataResponseSchema(Schema):
    )


class ChartDataAsyncResponseSchema(Schema):
    channel_id = fields.String(
        description="Unique session async channel ID", allow_none=False,
    )
    job_id = fields.String(description="Unique async job ID", allow_none=False,)
    user_id = fields.String(description="Requesting user ID", allow_none=True,)
    status = fields.String(description="Status value for async job", allow_none=False,)
    result_url = fields.String(
        description="Unique result URL for fetching async query data", allow_none=False,
    )


class ChartFavStarResponseResult(Schema):
    id = fields.Integer(description="The Chart id")
    value = fields.Boolean(description="The FaveStar value")
@ -1130,6 +1142,7 @@ class ImportV1ChartSchema(Schema):
|
|||
CHART_SCHEMAS = (
|
||||
ChartDataQueryContextSchema,
|
||||
ChartDataResponseSchema,
|
||||
ChartDataAsyncResponseSchema,
|
||||
# TODO: These should optimally be included in the QueryContext schema as an `anyOf`
|
||||
# in ChartDataPostPricessingOperation.options, but since `anyOf` is not
|
||||
# by Marshmallow<3, this is not currently possible.
|
||||
|
|
|
@@ -17,7 +17,7 @@
import copy
import logging
import math
from datetime import datetime, timedelta
from datetime import timedelta
from typing import Any, cast, ClassVar, Dict, List, Optional, Union

import numpy as np

@@ -30,13 +30,17 @@ from superset.charts.dao import ChartDAO
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.exceptions import (
    CacheLoadError,
    QueryObjectValidationError,
    SupersetException,
)
from superset.extensions import cache_manager, security_manager
from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import DTTM_ALIAS
from superset.views.utils import get_viz
from superset.viz import set_and_log_cache

config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]

@@ -78,6 +82,13 @@ class QueryContext:
        self.custom_cache_timeout = custom_cache_timeout
        self.result_type = result_type or utils.ChartDataResultType.FULL
        self.result_format = result_format or utils.ChartDataResultFormat.JSON
        self.cache_values = {
            "datasource": datasource,
            "queries": queries,
            "force": force,
            "result_type": result_type,
            "result_format": result_format,
        }

    def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
        """Returns a pandas dataframe based on the query object"""

@@ -142,8 +153,11 @@
        return df.to_dict(orient="records")

    def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
    def get_single_payload(
        self, query_obj: QueryObject, **kwargs: Any
    ) -> Dict[str, Any]:
        """Returns a payload of metadata and data"""
        force_cached = kwargs.get("force_cached", False)
        if self.result_type == utils.ChartDataResultType.QUERY:
            return {
                "query": self.datasource.get_query_str(query_obj.to_dict()),

@@ -159,8 +173,7 @@
            query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
            query_obj.row_offset = 0
            query_obj.columns = [o.column_name for o in self.datasource.columns]
        payload = self.get_df_payload(query_obj)

        payload = self.get_df_payload(query_obj, force_cached=force_cached)
        df = payload["df"]
        status = payload["status"]
        if status != utils.QueryStatus.FAILED:

@@ -186,9 +199,28 @@
            return {"data": payload["data"]}
        return payload

    def get_payload(self) -> List[Dict[str, Any]]:
        """Get all the payloads from the QueryObjects"""
        return [self.get_single_payload(query_object) for query_object in self.queries]
    def get_payload(self, **kwargs: Any) -> Dict[str, Any]:
        cache_query_context = kwargs.get("cache_query_context", False)
        force_cached = kwargs.get("force_cached", False)

        # Get all the payloads from the QueryObjects
        query_results = [
            self.get_single_payload(query_object, force_cached=force_cached)
            for query_object in self.queries
        ]
        return_value = {"queries": query_results}

        if cache_query_context:
            cache_key = self.cache_key()
            set_and_log_cache(
                cache_manager.cache,
                cache_key,
                {"data": self.cache_values},
                self.cache_timeout,
            )
            return_value["cache_key"] = cache_key  # type: ignore

        return return_value

    @property
    def cache_timeout(self) -> int:

@@ -203,7 +235,22 @@
            return self.datasource.database.cache_timeout
        return config["CACHE_DEFAULT_TIMEOUT"]

    def cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
    def cache_key(self, **extra: Any) -> str:
        """
        The QueryContext cache key is made out of the key/values from
        self.cache_values, plus any other key/values in `extra`. It includes only data
        required to rehydrate a QueryContext object.
        """
        key_prefix = "qc-"
        cache_dict = self.cache_values.copy()
        cache_dict.update(extra)

        return generate_cache_key(cache_dict, key_prefix)

    def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
        """
        Returns a QueryObject cache key for objects in self.queries
        """
        extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())

        cache_key = (

@@ -215,7 +262,7 @@
            and self.datasource.is_rls_supported
            else [],
            changed_on=self.datasource.changed_on,
            **kwargs
            **kwargs,
        )
        if query_obj
        else None

@@ -298,12 +345,12 @@
        self, query_obj: QueryObject, **kwargs: Any
    ) -> Dict[str, Any]:
        """Handles caching around the df payload retrieval"""
        cache_key = self.cache_key(query_obj, **kwargs)
        force_cached = kwargs.get("force_cached", False)
        cache_key = self.query_cache_key(query_obj)
        logger.info("Cache key: %s", cache_key)
        is_loaded = False
        stacktrace = None
        df = pd.DataFrame()
        cached_dttm = datetime.utcnow().isoformat().split(".")[0]
        cache_value = None
        status = None
        query = ""

@@ -327,6 +374,12 @@
            )
            logger.info("Serving from cache")

        if force_cached and not is_loaded:
            logger.warning(
                "force_cached (QueryContext): value not found for key %s", cache_key
            )
            raise CacheLoadError("Error loading data from cache")

        if query_obj and not is_loaded:
            try:
                invalid_columns = [

@@ -367,13 +420,11 @@
        if is_loaded and cache_key and status != utils.QueryStatus.FAILED:
            set_and_log_cache(
                cache_key=cache_key,
                df=df,
                query=query,
                annotation_data=annotation_data,
                cached_dttm=cached_dttm,
                cache_timeout=self.cache_timeout,
                datasource_uid=self.datasource.uid,
                cache_manager.data_cache,
                cache_key,
                {"df": df, "query": query, "annotation_data": annotation_data},
                self.cache_timeout,
                self.datasource.uid,
            )
        return {
            "cache_key": cache_key,
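A short illustration of the `force_cached` contract introduced above (a sketch under an assumed, already-hydrated `QueryContext`; not code from the commit): when the cache entry is missing, `get_payload` surfaces `CacheLoadError`, which `ChartDataCommand.run` translates into `ChartDataCacheLoadError` for the API layer.

```python
# Demonstrates the force_cached contract only; query_context is assumed hydrated.
from superset.exceptions import CacheLoadError


def fetch_cached_payload(query_context):
    try:
        # Only ever reads from the data cache; raises rather than querying the DB
        return query_context.get_payload(force_cached=True)
    except CacheLoadError:
        # The async worker has not (yet) written this result; caller maps to an error response
        return None
```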
@@ -277,7 +277,8 @@ class QueryObject:
        :param df: DataFrame returned from database model.
        :return: new DataFrame to which all post processing operations have been
            applied
        :raises ChartDataValidationError: If the post processing operation in incorrect
        :raises QueryObjectValidationError: If the post processing operation
            is incorrect
        """
        for post_process in self.post_processing:
            operation = post_process.get("operation")
@@ -327,6 +327,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
    "DISPLAY_MARKDOWN_HTML": True,
    # When True, this escapes HTML (rather than rendering it) in Markdown components
    "ESCAPE_MARKDOWN_HTML": False,
    "GLOBAL_ASYNC_QUERIES": False,
    "VERSIONED_EXPORT": False,
    # Note that: RowLevelSecurityFilter is only given by default to the Admin role
    # and the Admin Role does have the all_datasources security permission.

@@ -406,6 +407,9 @@ CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
# Cache for datasource metadata and query results
DATA_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}

# store cache keys by datasource UID (via CacheKey) for custom processing/invalidation
STORE_CACHE_KEYS_IN_METADATA_DB = False

# CORS Options
ENABLE_CORS = False
CORS_OPTIONS: Dict[Any, Any] = {}

@@ -965,6 +969,23 @@ SIP_15_TOAST_MESSAGE = (
# conventions and such. You can find examples in the tests.
SQLA_TABLE_MUTATOR = lambda table: table

# Global async query config options.
# Requires GLOBAL_ASYNC_QUERIES feature flag to be enabled.
GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
    "port": 6379,
    "host": "127.0.0.1",
    "password": "",
    "db": 0,
}
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-"
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
GLOBAL_ASYNC_QUERIES_POLLING_DELAY = 500

if CONFIG_PATH_ENV_VAR in os.environ:
    # Explicitly import config module that is not necessarily in pythonpath; useful
    # for case where app is being executed via pex.
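The block above defines the defaults; a deployment would typically override them in `superset_config.py`. A hedged example with illustrative values (hostnames and the secret are placeholders, not recommendations):

```python
# superset_config.py (illustrative values only)
FEATURE_FLAGS = {"GLOBAL_ASYNC_QUERIES": True}

# Async events are brokered over Redis streams
GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
    "host": "redis.example.internal",  # placeholder hostname
    "port": 6379,
    "password": "",
    "db": 0,
}
# Must be at least 32 bytes, or AsyncQueryManager.init_app() refuses to start
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "a-real-secret-at-least-32-bytes-long!"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = True  # when serving over HTTPS

# CACHE_CONFIG and DATA_CACHE_CONFIG must also be non-null backends;
# init_app() below raises if either is left as {"CACHE_TYPE": "null"}.
```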
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from flask_babel import gettext as _

@@ -65,6 +65,14 @@ class SupersetSecurityException(SupersetException):
        self.payload = payload


class SupersetVizException(SupersetException):
    status = 400

    def __init__(self, errors: List[SupersetError]) -> None:
        super(SupersetVizException, self).__init__(str(errors))
        self.errors = errors


class NoDataException(SupersetException):
    status = 400

@@ -93,6 +101,10 @@ class QueryObjectValidationError(SupersetException):
    status = 400


class CacheLoadError(SupersetException):
    status = 404


class DashboardImportException(SupersetException):
    pass
@@ -27,6 +27,7 @@ from flask_talisman import Talisman
from flask_wtf.csrf import CSRFProtect
from werkzeug.local import LocalProxy

from superset.utils.async_query_manager import AsyncQueryManager
from superset.utils.cache_manager import CacheManager
from superset.utils.feature_flag_manager import FeatureFlagManager
from superset.utils.machine_auth import MachineAuthProviderFactory

@@ -97,6 +98,7 @@ class UIManifestProcessor:
APP_DIR = os.path.dirname(__file__)
appbuilder = AppBuilder(update_perms=False)
async_query_manager = AsyncQueryManager()
cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from typing import Any, cast, Dict, Optional

from flask import current_app

from superset import app
from superset.exceptions import SupersetVizException
from superset.extensions import async_query_manager, cache_manager, celery_app
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.views.utils import get_datasource_info, get_viz

logger = logging.getLogger(__name__)
query_timeout = current_app.config[
    "SQLLAB_ASYNC_TIME_LIMIT_SEC"
]  # TODO: new config key


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
    job_metadata: Dict[str, Any], form_data: Dict[str, Any],
) -> None:
    from superset.charts.commands.data import (
        ChartDataCommand,
    )  # load here due to circular imports

    with app.app_context():  # type: ignore
        try:
            command = ChartDataCommand()
            command.set_query_context(form_data)
            result = command.run(cache=True)
            cache_key = result["cache_key"]
            result_url = f"/api/v1/chart/data/{cache_key}"
            async_query_manager.update_job(
                job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
            )
        except Exception as exc:
            # TODO: QueryContext should support SIP-40 style errors
            error = exc.message if hasattr(exc, "message") else str(exc)  # type: ignore # pylint: disable=no-member
            errors = [{"message": error}]
            async_query_manager.update_job(
                job_metadata, async_query_manager.STATUS_ERROR, errors=errors
            )
            raise exc

        return None


@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
def load_explore_json_into_cache(
    job_metadata: Dict[str, Any],
    form_data: Dict[str, Any],
    response_type: Optional[str] = None,
    force: bool = False,
) -> None:
    with app.app_context():  # type: ignore
        cache_key_prefix = "ejr-"  # ejr: explore_json request
        try:
            datasource_id, datasource_type = get_datasource_info(None, None, form_data)

            viz_obj = get_viz(
                datasource_type=cast(str, datasource_type),
                datasource_id=datasource_id,
                form_data=form_data,
                force=force,
            )
            # run query & cache results
            payload = viz_obj.get_payload()
            if viz_obj.has_error(payload):
                raise SupersetVizException(errors=payload["errors"])

            # cache form_data for async retrieval
            cache_value = {"form_data": form_data, "response_type": response_type}
            cache_key = generate_cache_key(cache_value, cache_key_prefix)
            set_and_log_cache(cache_manager.cache, cache_key, cache_value)
            result_url = f"/superset/explore_json/data/{cache_key}"
            async_query_manager.update_job(
                job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
            )
        except Exception as exc:
            if isinstance(exc, SupersetVizException):
                errors = exc.errors  # pylint: disable=no-member
            else:
                error = (
                    exc.message if hasattr(exc, "message") else str(exc)  # type: ignore # pylint: disable=no-member
                )
                errors = [error]

            async_query_manager.update_job(
                job_metadata, async_query_manager.STATUS_ERROR, errors=errors
            )
            raise exc

        return None
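To make the worker contract concrete, a sketch of the hand-off between the web process and the Celery worker (editorial; `channel_id` and `form_data` are assumed to be in scope from the request):

```python
# Sketch: the web process only enqueues; completion is reported via Redis.
from superset.extensions import async_query_manager
from superset.tasks.async_queries import load_chart_data_into_cache

job_metadata = async_query_manager.init_job(channel_id)  # channel_id from the JWT cookie
load_chart_data_into_cache.delay(job_metadata, form_data)
# The browser then polls /api/v1/async_event until it sees an event like:
# {"job_id": "...", "status": "done",
#  "result_url": "/api/v1/chart/data/qc-<md5-of-query-context>"}
```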
@@ -0,0 +1,199 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple

import jwt
import redis
from flask import Flask, Request, Response, session

logger = logging.getLogger(__name__)


class AsyncQueryTokenException(Exception):
    pass


class AsyncQueryJobException(Exception):
    pass


def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]:
    return {
        "channel_id": channel_id,
        "job_id": job_id,
        "user_id": session.get("user_id"),
        "status": kwargs.get("status"),
        "errors": kwargs.get("errors", []),
        "result_url": kwargs.get("result_url"),
    }


def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]:
    event_id = event_data[0]
    event_payload = event_data[1]["data"]
    return {"id": event_id, **json.loads(event_payload)}


def increment_id(redis_id: str) -> str:
    # redis stream IDs are in this format: '1607477697866-0'
    try:
        prefix, last = redis_id[:-1], int(redis_id[-1])
        return prefix + str(last + 1)
    except Exception:  # pylint: disable=broad-except
        return redis_id


class AsyncQueryManager:
    MAX_EVENT_COUNT = 100
    STATUS_PENDING = "pending"
    STATUS_RUNNING = "running"
    STATUS_ERROR = "error"
    STATUS_DONE = "done"

    def __init__(self) -> None:
        super().__init__()
        self._redis: redis.Redis
        self._stream_prefix: str = ""
        self._stream_limit: Optional[int]
        self._stream_limit_firehose: Optional[int]
        self._jwt_cookie_name: str
        self._jwt_cookie_secure: bool = False
        self._jwt_secret: str

    def init_app(self, app: Flask) -> None:
        config = app.config
        if (
            config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
            or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
        ):
            raise Exception(
                """
                Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
                and non-null in order to enable async queries
                """
            )

        if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
            raise AsyncQueryTokenException(
                "Please provide a JWT secret at least 32 bytes long"
            )

        self._redis = redis.Redis(  # type: ignore
            **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
        )
        self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
        self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
        self._stream_limit_firehose = config[
            "GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE"
        ]
        self._jwt_cookie_name = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"]
        self._jwt_cookie_secure = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE"]
        self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]

        @app.after_request
        def validate_session(  # pylint: disable=unused-variable
            response: Response,
        ) -> Response:
            reset_token = False
            user_id = session["user_id"] if "user_id" in session else None

            if "async_channel_id" not in session or "async_user_id" not in session:
                reset_token = True
            elif user_id != session["async_user_id"]:
                reset_token = True

            if reset_token:
                async_channel_id = str(uuid.uuid4())
                session["async_channel_id"] = async_channel_id
                session["async_user_id"] = user_id

                sub = str(user_id) if user_id else None
                token = self.generate_jwt({"channel": async_channel_id, "sub": sub})

                response.set_cookie(
                    self._jwt_cookie_name,
                    value=token,
                    httponly=True,
                    secure=self._jwt_cookie_secure,
                    # max_age=max_age or config.cookie_max_age,
                    # domain=config.cookie_domain,
                    # path=config.access_cookie_path,
                    # samesite=config.cookie_samesite
                )

            return response

    def generate_jwt(self, data: Dict[str, Any]) -> str:
        encoded_jwt = jwt.encode(data, self._jwt_secret, algorithm="HS256")
        return encoded_jwt.decode("utf-8")

    def parse_jwt(self, token: str) -> Dict[str, Any]:
        data = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
        return data

    def parse_jwt_from_request(self, request: Request) -> Dict[str, Any]:
        token = request.cookies.get(self._jwt_cookie_name)
        if not token:
            raise AsyncQueryTokenException("Token not present")

        try:
            return self.parse_jwt(token)
        except Exception as exc:
            logger.warning(exc)
            raise AsyncQueryTokenException("Failed to parse token")

    def init_job(self, channel_id: str) -> Dict[str, Any]:
        job_id = str(uuid.uuid4())
        return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING)

    def read_events(
        self, channel: str, last_id: Optional[str]
    ) -> List[Optional[Dict[str, Any]]]:
        stream_name = f"{self._stream_prefix}{channel}"
        start_id = increment_id(last_id) if last_id else "-"
        results = self._redis.xrange(  # type: ignore
            stream_name, start_id, "+", self.MAX_EVENT_COUNT
        )
        return [] if not results else list(map(parse_event, results))

    def update_job(
        self, job_metadata: Dict[str, Any], status: str, **kwargs: Any
    ) -> None:
        if "channel_id" not in job_metadata:
            raise AsyncQueryJobException("No channel ID specified")

        if "job_id" not in job_metadata:
            raise AsyncQueryJobException("No job ID specified")

        updates = {"status": status, **kwargs}
        event_data = {"data": json.dumps({**job_metadata, **updates})}

        full_stream_name = f"{self._stream_prefix}full"
        scoped_stream_name = f"{self._stream_prefix}{job_metadata['channel_id']}"

        logger.debug("********** logging event data to stream %s", scoped_stream_name)
        logger.debug(event_data)

        self._redis.xadd(  # type: ignore
            scoped_stream_name, event_data, "*", self._stream_limit
        )
        self._redis.xadd(  # type: ignore
            full_stream_name, event_data, "*", self._stream_limit_firehose
        )
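One detail worth calling out from `increment_id` above: last-ID paging for `/api/v1/async_event` bumps only the final character of the Redis stream ID. A quick illustration (editorial, not from the commit):

```python
from superset.utils.async_query_manager import increment_id

# Redis stream IDs are "<ms-timestamp>-<sequence>"
assert increment_id("1607477697866-0") == "1607477697866-1"
# Caveat: only the last character is treated as the sequence, so a multi-digit
# sequence ("...-19") becomes "...-110" rather than "...-20"; adequate for
# paging past a just-read event, but not a general stream-ID parser.
```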
@@ -14,16 +14,65 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import hashlib
import json
import logging
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from flask import current_app as app, request
from flask_caching import Cache
from werkzeug.wrappers.etag import ETagResponseMixin

from superset import db
from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.stats_logger import BaseStatsLogger
from superset.utils.core import json_int_dttm_ser

config = app.config  # type: ignore
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)


# TODO: DRY up cache key code
def json_dumps(obj: Any, sort_keys: bool = False) -> str:
    return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys)


def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str:
    json_data = json_dumps(values_dict, sort_keys=True)
    hash_str = hashlib.md5(json_data.encode("utf-8")).hexdigest()
    return f"{key_prefix}{hash_str}"


def set_and_log_cache(
    cache_instance: Cache,
    cache_key: str,
    cache_value: Dict[str, Any],
    cache_timeout: Optional[int] = None,
    datasource_uid: Optional[str] = None,
) -> None:
    timeout = cache_timeout if cache_timeout else config["CACHE_DEFAULT_TIMEOUT"]
    try:
        dttm = datetime.utcnow().isoformat().split(".")[0]
        value = {**cache_value, "dttm": dttm}
        cache_instance.set(cache_key, value, timeout=timeout)
        stats_logger.incr("set_cache_key")

        if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
            ck = CacheKey(
                cache_key=cache_key,
                cache_timeout=cache_timeout,
                datasource_uid=datasource_uid,
            )
            db.session.add(ck)
    except Exception as ex:  # pylint: disable=broad-except
        # cache.set call can fail if the backend is down or if
        # the key is too large or whatever other reasons
        logger.warning("Could not cache key %s", cache_key)
        logger.exception(ex)


# If a user sets `max_age` to 0, for how long should the browser cache the
# resource? Flask-Caching will cache forever, but for the HTTP header we need
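The reason `generate_cache_key` serializes with `sort_keys=True` is that the web worker and the Celery worker must derive identical `qc-`/`ejr-` keys from equal dicts regardless of insertion order. A standalone sketch of that property (using plain `json.dumps` instead of the `json_int_dttm_ser` default shown above):

```python
import hashlib
import json


def generate_cache_key(values_dict, key_prefix=""):
    # sort_keys=True makes the hash independent of dict insertion order
    json_data = json.dumps(values_dict, sort_keys=True)
    return key_prefix + hashlib.md5(json_data.encode("utf-8")).hexdigest()


assert generate_cache_key({"a": 1, "b": 2}, "qc-") == generate_cache_key(
    {"b": 2, "a": 1}, "qc-"
)
```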
@@ -238,7 +238,7 @@ def pivot(  # pylint: disable=too-many-arguments
        Default to 'All'.
    :param flatten_columns: Convert column names to strings
    :return: A pivot table
    :raises ChartDataValidationError: If the request in incorrect
    :raises QueryObjectValidationError: If the request is incorrect
    """
    if not index:
        raise QueryObjectValidationError(

@@ -293,7 +293,7 @@ def aggregate(
    :param groupby: columns to aggregate
    :param aggregates: A mapping from metric column to the function used to
        aggregate values.
    :raises ChartDataValidationError: If the request in incorrect
    :raises QueryObjectValidationError: If the request is incorrect
    """
    aggregates = aggregates or {}
    aggregate_funcs = _get_aggregate_funcs(df, aggregates)

@@ -313,7 +313,7 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
    :param columns: columns by which to sort. The key specifies the column name,
        value specifies if sorting in ascending order.
    :return: Sorted DataFrame
    :raises ChartDataValidationError: If the request in incorrect
    :raises QueryObjectValidationError: If the request is incorrect
    """
    return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))

@@ -348,7 +348,7 @@ def rolling(  # pylint: disable=too-many-arguments
    :param min_periods: The minimum amount of periods required for a row to be included
        in the result set.
    :return: DataFrame with the rolling columns
    :raises ChartDataValidationError: If the request in incorrect
    :raises QueryObjectValidationError: If the request is incorrect
    """
    rolling_type_options = rolling_type_options or {}
    df_rolling = df[columns.keys()]

@@ -408,7 +408,7 @@ def select(
        For instance, `{'y': 'y2'}` will rename the column `y` to
        `y2`.
    :return: Subset of columns in original DataFrame
    :raises ChartDataValidationError: If the request in incorrect
    :raises QueryObjectValidationError: If the request is incorrect
    """
    df_select = df.copy(deep=False)
    if columns:

@@ -433,7 +433,7 @@ def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame
        unchanged.
    :param periods: periods to shift for calculating difference.
    :return: DataFrame with diffed columns
    :raises ChartDataValidationError: If the request in incorrect
    :raises QueryObjectValidationError: If the request is incorrect
    """
    df_diff = df[columns.keys()]
    df_diff = df_diff.diff(periods=periods)
@@ -44,7 +44,8 @@ class Api(BaseSupersetView):
        """
        query_context = QueryContext(**json.loads(request.form["query_context"]))
        query_context.raise_for_access()
        payload_json = query_context.get_payload()
        result = query_context.get_payload()
        payload_json = result["queries"]
        return json.dumps(
            payload_json, default=utils.json_int_dttm_ser, ignore_nan=True
        )
@@ -77,6 +77,8 @@ FRONTEND_CONF_KEYS = (
    "SUPERSET_WEBSERVER_DOMAINS",
    "SQLLAB_SAVE_WARNING_MESSAGE",
    "DISPLAY_MAX_ROW",
    "GLOBAL_ASYNC_QUERIES_TRANSPORT",
    "GLOBAL_ASYNC_QUERIES_POLLING_DELAY",
)
logger = logging.getLogger(__name__)

@@ -126,6 +126,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
    method_permission_name = {
        "bulk_delete": "delete",
        "data": "list",
        "data_from_cache": "list",
        "delete": "delete",
        "distinct": "list",
        "export": "mulexport",
@@ -28,7 +28,11 @@ import simplejson as json
from flask import abort, flash, g, Markup, redirect, render_template, request, Response
from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access, has_access_api
from flask_appbuilder.security.decorators import (
    has_access,
    has_access_api,
    permission_name,
)
from flask_appbuilder.security.sqla import models as ab_models
from flask_babel import gettext as __, lazy_gettext as _
from jinja2.exceptions import TemplateError

@@ -66,6 +70,7 @@ from superset.dashboards.dao import DashboardDAO
from superset.databases.dao import DatabaseDAO
from superset.databases.filters import DatabaseFilter
from superset.exceptions import (
    CacheLoadError,
    CertificateException,
    DatabaseNotFound,
    SerializationError,

@@ -73,6 +78,7 @@ from superset.exceptions import (
    SupersetSecurityException,
    SupersetTimeoutException,
)
from superset.extensions import async_query_manager, cache_manager
from superset.jinja_context import get_template_processor
from superset.models.core import Database, FavStar, Log
from superset.models.dashboard import Dashboard

@@ -87,8 +93,10 @@ from superset.security.analytics_db_safety import (
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.tasks.async_queries import load_explore_json_into_cache
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.dates import now_as_float
from superset.views.base import (

@@ -113,6 +121,7 @@ from superset.views.utils import (
    apply_display_max_row_limit,
    bootstrap_user_data,
    check_datasource_perms,
    check_explore_cache_perms,
    check_slice_perms,
    get_cta_schema_name,
    get_dashboard_extra_filters,

@@ -484,6 +493,43 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
        payload = viz_obj.get_payload()
        return data_payload_response(*viz_obj.payload_json_and_has_error(payload))

    @event_logger.log_this
    @api
    @has_access_api
    @handle_api_exception
    @permission_name("explore_json")
    @expose("/explore_json/data/<cache_key>", methods=["GET"])
    @etag_cache(check_perms=check_explore_cache_perms)
    def explore_json_data(self, cache_key: str) -> FlaskResponse:
        """Serves cached result data for async explore_json calls

        `self.generate_json` receives this input and returns different
        payloads based on the request args in the first block

        TODO: form_data should not be loaded twice from cache
        (also loaded in `check_explore_cache_perms`)
        """
        try:
            cached = cache_manager.cache.get(cache_key)
            if not cached:
                raise CacheLoadError("Cached data not found")

            form_data = cached.get("form_data")
            response_type = cached.get("response_type")

            datasource_id, datasource_type = get_datasource_info(None, None, form_data)

            viz_obj = get_viz(
                datasource_type=cast(str, datasource_type),
                datasource_id=datasource_id,
                form_data=form_data,
                force_cached=True,
            )

            return self.generate_json(viz_obj, response_type)
        except SupersetException as ex:
            return json_error_response(utils.error_msg_from_exception(ex), 400)

    EXPLORE_JSON_METHODS = ["POST"]
    if not is_feature_enabled("ENABLE_EXPLORE_JSON_CSRF_PROTECTION"):
        EXPLORE_JSON_METHODS.append("GET")

@@ -528,11 +574,31 @@
            datasource_id, datasource_type, form_data
        )

        force = request.args.get("force") == "true"

        # TODO: support CSV, SQL query and other non-JSON types
        if (
            is_feature_enabled("GLOBAL_ASYNC_QUERIES")
            and response_type == utils.ChartDataResultFormat.JSON
        ):
            try:
                async_channel_id = async_query_manager.parse_jwt_from_request(
                    request
                )["channel"]
                job_metadata = async_query_manager.init_job(async_channel_id)
                load_explore_json_into_cache.delay(
                    job_metadata, form_data, response_type, force
                )
            except AsyncQueryTokenException:
                return json_error_response("Not authorized", 401)

            return json_success(json.dumps(job_metadata), status=202)

        viz_obj = get_viz(
            datasource_type=cast(str, datasource_type),
            datasource_id=datasource_id,
            form_data=form_data,
            force=request.args.get("force") == "true",
            force=force,
        )

        return self.generate_json(viz_obj, response_type)
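End to end, the async `explore_json` path added above behaves roughly like this from a client's perspective. The sketch below is editorial and hedged: it uses `requests` against an assumed local dev URL, assumes the login and `async-token` cookies are already on the session, and `payload` stands in for the usual `form_data` request body.

```python
import time

import requests

session = requests.Session()  # assumes login + async-token cookies are already set

# 1. POST explore_json: with GLOBAL_ASYNC_QUERIES on and a JSON response type,
#    the server returns 202 plus job metadata instead of chart data.
resp = session.post("http://localhost:8088/superset/explore_json/", data=payload)
job = resp.json()  # {"job_id": ..., "channel_id": ..., "status": "pending", ...}

# 2. Poll the async event stream until our job reports "done".
last_id = None
while True:
    params = {"last_id": last_id} if last_id else {}
    events = session.get(
        "http://localhost:8088/api/v1/async_event/", params=params
    ).json()["result"]
    done = [e for e in events if e["job_id"] == job["job_id"] and e["status"] == "done"]
    if done:
        result_url = done[0]["result_url"]  # e.g. /superset/explore_json/data/ejr-<hash>
        break
    if events:
        last_id = events[-1]["id"]
    time.sleep(0.5)  # GLOBAL_ASYNC_QUERIES_POLLING_DELAY defaults to 500ms

# 3. Fetch the cached result from the URL the worker published.
data = session.get(f"http://localhost:8088{result_url}").json()
```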
@@ -34,10 +34,12 @@ from superset import app, dataframe, db, is_feature_enabled, result_set
from superset.connectors.connector_registry import ConnectorRegistry
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
    CacheLoadError,
    SerializationError,
    SupersetException,
    SupersetSecurityException,
)
from superset.extensions import cache_manager
from superset.legacy import update_time_range
from superset.models.core import Database
from superset.models.dashboard import Dashboard

@@ -108,13 +110,19 @@ def get_permissions(


def get_viz(
    form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
    form_data: FormData,
    datasource_type: str,
    datasource_id: int,
    force: bool = False,
    force_cached: bool = False,
) -> BaseViz:
    viz_type = form_data.get("viz_type", "table")
    datasource = ConnectorRegistry.get_datasource(
        datasource_type, datasource_id, db.session
    )
    viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
    viz_obj = viz.viz_types[viz_type](
        datasource, form_data=form_data, force=force, force_cached=force_cached
    )
    return viz_obj

@@ -422,10 +430,26 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
    return obj and user in obj.owners


def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
    """
    Loads async explore_json request data from cache and performs access check

    :param _self: the Superset view instance
    :param cache_key: the cache key passed into /explore_json/data/
    :raises SupersetSecurityException: If the user cannot access the resource
    """
    cached = cache_manager.cache.get(cache_key)
    if not cached:
        raise CacheLoadError("Cached data not found")

    check_datasource_perms(_self, form_data=cached["form_data"])


def check_datasource_perms(
    _self: Any,
    datasource_type: Optional[str] = None,
    datasource_id: Optional[int] = None,
    **kwargs: Any
) -> None:
    """
    Check if user can access a cached response from explore_json.

@@ -438,7 +462,7 @@ def check_datasource_perms(
    :raises SupersetSecurityException: If the user cannot access the resource
    """

    form_data = get_form_data()[0]
    form_data = kwargs["form_data"] if "form_data" in kwargs else get_form_data()[0]

    try:
        datasource_id, datasource_type = get_datasource_info(
@@ -38,6 +38,7 @@ from typing import (
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

@@ -57,6 +58,7 @@ from superset import app, db, is_feature_enabled
from superset.constants import NULL_STRING
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
    CacheLoadError,
    NullValueException,
    QueryObjectValidationError,
    SpatialException,

@@ -66,6 +68,7 @@ from superset.models.cache import CacheKey
from superset.models.helpers import QueryResult
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
    DTTM_ALIAS,
    JS_MAX_INTEGER,

@@ -97,37 +100,6 @@ METRIC_KEYS = [
]


def set_and_log_cache(
    cache_key: str,
    df: pd.DataFrame,
    query: str,
    cached_dttm: str,
    cache_timeout: int,
    datasource_uid: Optional[str],
    annotation_data: Optional[Dict[str, Any]] = None,
) -> None:
    try:
        cache_value = dict(
            dttm=cached_dttm, df=df, query=query, annotation_data=annotation_data or {}
        )
        stats_logger.incr("set_cache_key")
        cache_manager.data_cache.set(cache_key, cache_value, timeout=cache_timeout)

        if datasource_uid:
            ck = CacheKey(
                cache_key=cache_key,
                cache_timeout=cache_timeout,
                datasource_uid=datasource_uid,
            )
            db.session.add(ck)
    except Exception as ex:
        # cache.set call can fail if the backend is down or if
        # the key is too large or whatever other reasons
        logger.warning("Could not cache key {}".format(cache_key))
        logger.exception(ex)
        cache_manager.data_cache.delete(cache_key)


class BaseViz:

    """All visualizations derive this base class"""

@@ -144,6 +116,7 @@ class BaseViz:
        datasource: "BaseDatasource",
        form_data: Dict[str, Any],
        force: bool = False,
        force_cached: bool = False,
    ) -> None:
        if not datasource:
            raise QueryObjectValidationError(_("Viz is missing a datasource"))

@@ -164,6 +137,7 @@ class BaseViz:
        self.results: Optional[QueryResult] = None
        self.errors: List[Dict[str, Any]] = []
        self.force = force
        self._force_cached = force_cached
        self.from_dttm: Optional[datetime] = None
        self.to_dttm: Optional[datetime] = None

@@ -180,6 +154,10 @@ class BaseViz:
        self.applied_filters: List[Dict[str, str]] = []
        self.rejected_filters: List[Dict[str, str]] = []

    @property
    def force_cached(self) -> bool:
        return self._force_cached

    def process_metrics(self) -> None:
        # metrics in TableViz is order sensitive, so metric_dict should be
        # OrderedDict

@@ -272,7 +250,7 @@
                "columns": [o.column_name for o in self.datasource.columns],
            }
        )
        df = self.get_df(query_obj)
        df = self.get_df_payload(query_obj)["df"]  # leverage caching logic
        return df.to_dict(orient="records")

    def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:

@@ -518,7 +496,6 @@
        is_loaded = False
        stacktrace = None
        df = None
        cached_dttm = datetime.utcnow().isoformat().split(".")[0]
        if cache_key and cache_manager.data_cache and not self.force:
            cache_value = cache_manager.data_cache.get(cache_key)
            if cache_value:

@@ -539,6 +516,11 @@
                logger.info("Serving from cache")

        if query_obj and not is_loaded:
            if self.force_cached:
                logger.warning(
                    f"force_cached (viz.py): value not found for cache key {cache_key}"
                )
                raise CacheLoadError(_("Cached value not found"))
            try:
                invalid_columns = [
                    col

@@ -590,12 +572,11 @@
        if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED:
            set_and_log_cache(
                cache_key=cache_key,
                df=df,
                query=self.query,
                cached_dttm=cached_dttm,
                cache_timeout=self.cache_timeout,
                datasource_uid=self.datasource.uid,
                cache_manager.data_cache,
                cache_key,
                {"df": df, "query": self.query},
                self.cache_timeout,
                self.datasource.uid,
            )
        return {
            "cache_key": self._any_cache_key,

@@ -618,13 +599,15 @@
            obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
        )

    def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
        has_error = (
    def has_error(self, payload: VizPayload) -> bool:
        return (
            payload.get("status") == utils.QueryStatus.FAILED
            or payload.get("error") is not None
            or bool(payload.get("errors"))
        )
        return self.json_dumps(payload), has_error

    def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
        return self.json_dumps(payload), self.has_error(payload)

    @property
    def data(self) -> Dict[str, Any]:

@@ -638,7 +621,7 @@
        return content

    def get_csv(self) -> Optional[str]:
        df = self.get_df()
        df = self.get_df_payload()["df"]  # leverage caching logic
        include_index = not isinstance(df.index, pd.RangeIndex)
        return df.to_csv(index=include_index, **config["CSV_EXPORT"])

@@ -1721,7 +1704,7 @@ class SunburstViz(BaseViz):
    def get_data(self, df: pd.DataFrame) -> VizData:
        if df.empty:
            return None
        fd = self.form_data
        fd = copy.deepcopy(self.form_data)
        cols = fd.get("groupby") or []
        cols.extend(["m1", "m2"])
        metric = utils.get_metric_name(fd["metric"])

@@ -2983,12 +2966,14 @@ class PartitionViz(NVD3TimeSeriesViz):
        return self.nest_values(levels)


def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]:
    return set(cls.__subclasses__()).union(
        [sc for c in cls.__subclasses__() for sc in get_subclasses(c)]
    )


viz_types = {
    o.viz_type: o
    for o in globals().values()
    if (
        inspect.isclass(o)
        and issubclass(o, BaseViz)
        and o.viz_type not in config["VIZ_TYPE_DENYLIST"]
    )
    for o in get_subclasses(BaseViz)
    if o.viz_type not in config["VIZ_TYPE_DENYLIST"]
}
@@ -57,13 +57,13 @@ from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
    DTTM_ALIAS,
    JS_MAX_INTEGER,
    merge_extra_filters,
    to_adhoc,
)
from superset.viz import set_and_log_cache

if TYPE_CHECKING:
    from superset.connectors.base.models import BaseDatasource

@@ -518,10 +518,9 @@ class BaseViz:

        if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED:
            set_and_log_cache(
                cache_manager.data_cache,
                cache_key,
                df,
                self.query,
                cached_dttm,
                {"df": df, "query": self.query},
                self.cache_timeout,
                self.datasource.uid,
            )
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Optional
from unittest import mock

from superset.extensions import async_query_manager
from tests.base_tests import SupersetTestCase
from tests.test_app import app


class TestAsyncEventApi(SupersetTestCase):
    UUID = "943c920-32a5-412a-977d-b8e47d36f5a4"

    def fetch_events(self, last_id: Optional[str] = None):
        base_uri = "api/v1/async_event/"
        uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri
        return self.client.get(uri)

    @mock.patch("uuid.uuid4", return_value=UUID)
    def test_events(self, mock_uuid4):
        async_query_manager.init_app(app)
        self.login(username="admin")
        with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
            rv = self.fetch_events()
            response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
        mock_xrange.assert_called_with(channel_id, "-", "+", 100)
        self.assertEqual(response, {"result": []})

    @mock.patch("uuid.uuid4", return_value=UUID)
    def test_events_last_id(self, mock_uuid4):
        async_query_manager.init_app(app)
        self.login(username="admin")
        with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
            rv = self.fetch_events("1607471525180-0")
            response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
        mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
        self.assertEqual(response, {"result": []})

    @mock.patch("uuid.uuid4", return_value=UUID)
    def test_events_results(self, mock_uuid4):
        async_query_manager.init_app(app)
        self.login(username="admin")
        with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
            mock_xrange.return_value = [
                (
                    "1607477697866-0",
                    {
                        "data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463"}'
                    },
                ),
                (
                    "1607477697993-0",
                    {
                        "data": '{"channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f", "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c", "user_id": "1", "status": "done", "errors": [], "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d"}'
                    },
                ),
            ]
            rv = self.fetch_events()
            response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
        mock_xrange.assert_called_with(channel_id, "-", "+", 100)
        expected = {
            "result": [
                {
                    "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
                    "errors": [],
                    "id": "1607477697866-0",
                    "job_id": "10a0bd9a-03c8-4737-9345-f4234ba86512",
                    "result_url": "/api/v1/chart/data/qc-ecd766dd461f294e1bcdaa321e0e8463",
                    "status": "done",
                    "user_id": "1",
                },
                {
                    "channel_id": "1095c1c9-b6b1-444d-aa83-8e323b32831f",
                    "errors": [],
                    "id": "1607477697993-0",
                    "job_id": "027cbe49-26ce-4813-bb5a-0b95a626b84c",
                    "result_url": "/api/v1/chart/data/qc-1bbc3a240e7039ba4791aefb3a7ee80d",
                    "status": "done",
                    "user_id": "1",
                },
            ]
        }
        self.assertEqual(response, expected)

    def test_events_no_login(self):
        async_query_manager.init_app(app)
        rv = self.fetch_events()
        assert rv.status_code == 401

    def test_events_no_token(self):
        self.login(username="admin")
        self.client.set_cookie(
            "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], ""
        )
        rv = self.fetch_events()
        assert rv.status_code == 401
@@ -35,6 +35,7 @@ class TestCache(SupersetTestCase):
        cache_manager.data_cache.clear()

    def test_no_data_cache(self):
        data_cache_config = app.config["DATA_CACHE_CONFIG"]
        app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}
        cache_manager.init_app(app)

@@ -48,11 +49,15 @@
        resp_from_cache = self.get_json_resp(
            json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
        )
        # restore DATA_CACHE_CONFIG
        app.config["DATA_CACHE_CONFIG"] = data_cache_config
        self.assertFalse(resp["is_cached"])
        self.assertFalse(resp_from_cache["is_cached"])

    def test_slice_data_cache(self):
        # Override cache config
        data_cache_config = app.config["DATA_CACHE_CONFIG"]
        cache_default_timeout = app.config["CACHE_DEFAULT_TIMEOUT"]
        app.config["CACHE_DEFAULT_TIMEOUT"] = 100
        app.config["DATA_CACHE_CONFIG"] = {
            "CACHE_TYPE": "simple",

@@ -87,5 +92,6 @@
        self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"]))

        # reset cache config
        app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "null"}
        app.config["DATA_CACHE_CONFIG"] = data_cache_config
        app.config["CACHE_DEFAULT_TIMEOUT"] = cache_default_timeout
        cache_manager.init_app(app)
@ -31,21 +31,21 @@ from sqlalchemy import and_
|
|||
from sqlalchemy.sql import func
|
||||
|
||||
from tests.test_app import app
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.utils.core import AnnotationType, get_example_database
|
||||
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
|
||||
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
|
||||
from superset.charts.commands.data import ChartDataCommand
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.extensions import async_query_manager, cache_manager, db, security_manager
|
||||
from superset.models.annotations import AnnotationLayer
|
||||
from superset.models.core import Database, FavStar, FavStarClassName
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.reports import ReportSchedule, ReportScheduleType
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import AnnotationType, get_example_database
|
||||
|
||||
from tests.base_api_tests import ApiOwnersTestCaseMixin
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.base_tests import SupersetTestCase, post_assert_metric, test_client
|
||||
|
||||
from tests.fixtures.importexport import (
|
||||
chart_config,
|
||||
chart_metadata_config,
|
||||
|
@ -53,6 +53,7 @@ from tests.fixtures.importexport import (
|
|||
dataset_config,
|
||||
dataset_metadata_config,
|
||||
)
|
||||
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
|
||||
from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS
|
||||
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
|
||||
from tests.annotation_layers.fixtures import create_annotation_layers
|
||||
|
@ -99,6 +100,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
|
|||
db.session.commit()
|
||||
return slice
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_data_cache(self):
|
||||
with app.app_context():
|
||||
cache_manager.data_cache.clear()
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def create_charts(self):
|
||||
with self.create_app().app_context():
|
||||
|
@ -1287,6 +1294,155 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
|
|||
if get_example_database().backend != "presto":
|
||||
assert "('boy' = 'boy')" in result
|
||||
|
||||
@mock.patch.dict(
|
||||
"superset.extensions.feature_flag_manager._feature_flags",
|
||||
GLOBAL_ASYNC_QUERIES=True,
|
||||
)
|
||||
def test_chart_data_async(self):
|
||||
"""
|
||||
Chart data API: Test chart data query (async)
|
||||
"""
|
||||
async_query_manager.init_app(app)
|
||||
self.login(username="admin")
|
||||
table = self.get_table_by_name("birth_names")
|
||||
request_payload = get_query_context(table.name, table.id, table.type)
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
self.assertEqual(rv.status_code, 202)
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
keys = list(data.keys())
|
||||
self.assertCountEqual(
|
||||
keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
|
||||
)
|
||||
|
||||
    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    def test_chart_data_async_results_type(self):
        """
        Chart data API: Test chart data query non-JSON format (async)
        """
        async_query_manager.init_app(app)
        self.login(username="admin")
        table = self.get_table_by_name("birth_names")
        request_payload = get_query_context(table.name, table.id, table.type)
        request_payload["result_type"] = "results"
        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
        self.assertEqual(rv.status_code, 200)

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    def test_chart_data_async_invalid_token(self):
        """
        Chart data API: Test chart data query with invalid token (async)
        """
        async_query_manager.init_app(app)
        self.login(username="admin")
        table = self.get_table_by_name("birth_names")
        request_payload = get_query_context(table.name, table.id, table.type)
        test_client.set_cookie(
            "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
        )
        rv = post_assert_metric(test_client, CHART_DATA_URI, request_payload, "data")
        self.assertEqual(rv.status_code, 401)

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    @mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
    def test_chart_data_cache(self, load_qc_mock):
        """
        Chart data cache API: Test chart data async cache request
        """
        async_query_manager.init_app(app)
        self.login(username="admin")
        table = self.get_table_by_name("birth_names")
        query_context = get_query_context(table.name, table.id, table.type)
        load_qc_mock.return_value = query_context
        orig_run = ChartDataCommand.run

        def mock_run(self, **kwargs):
            assert kwargs["force_cached"] == True
            # override force_cached to get result from DB
            return orig_run(self, force_cached=False)

        with mock.patch.object(ChartDataCommand, "run", new=mock_run):
            rv = self.get_assert_metric(
                f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
            )
            data = json.loads(rv.data.decode("utf-8"))

        self.assertEqual(rv.status_code, 200)
        self.assertEqual(data["result"][0]["rowcount"], 45)
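The `mock_run` wrapper above both verifies that the cache endpoint always invokes the command with `force_cached=True` and then delegates to the real implementation with the flag flipped, so the test still gets data from the DB. A minimal, self-contained sketch of this wrap-and-delegate patching pattern (the `Command` class here is hypothetical):

```python
from unittest import mock


class Command:
    def run(self, **kwargs):
        return kwargs


orig_run = Command.run


def wrapped_run(self, **kwargs):
    # Verify the caller's intent, then delegate with the flag overridden.
    assert kwargs.get("force_cached") is True
    return orig_run(self, force_cached=False)


with mock.patch.object(Command, "run", new=wrapped_run):
    assert Command().run(force_cached=True) == {"force_cached": False}
```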
    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    @mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
    def test_chart_data_cache_run_failed(self, load_qc_mock):
        """
        Chart data cache API: Test chart data async cache request with run failure
        """
        async_query_manager.init_app(app)
        self.login(username="admin")
        table = self.get_table_by_name("birth_names")
        query_context = get_query_context(table.name, table.id, table.type)
        load_qc_mock.return_value = query_context
        rv = self.get_assert_metric(
            f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
        )
        data = json.loads(rv.data.decode("utf-8"))

        self.assertEqual(rv.status_code, 422)
        self.assertEqual(data["message"], "Error loading data from cache")

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    @mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
    def test_chart_data_cache_no_login(self, load_qc_mock):
        """
        Chart data cache API: Test chart data async cache request (no login)
        """
        async_query_manager.init_app(app)
        table = self.get_table_by_name("birth_names")
        query_context = get_query_context(table.name, table.id, table.type)
        load_qc_mock.return_value = query_context
        orig_run = ChartDataCommand.run

        def mock_run(self, **kwargs):
            assert kwargs["force_cached"] == True
            # override force_cached to get result from DB
            return orig_run(self, force_cached=False)

        with mock.patch.object(ChartDataCommand, "run", new=mock_run):
            rv = self.get_assert_metric(
                f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
            )

        self.assertEqual(rv.status_code, 401)

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    def test_chart_data_cache_key_error(self):
        """
        Chart data cache API: Test chart data async cache request with invalid cache key
        """
        async_query_manager.init_app(app)
        self.login(username="admin")
        rv = self.get_assert_metric(
            f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
        )

        self.assertEqual(rv.status_code, 404)

    def test_export_chart(self):
        """
        Chart API: Test export chart
@ -52,6 +52,7 @@ from superset import (
|
|||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
from superset.extensions import async_query_manager
|
||||
from superset.models import core as models
|
||||
from superset.models.annotations import Annotation, AnnotationLayer
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
@ -602,10 +603,13 @@ class TestCore(SupersetTestCase):
|
|||
) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
|
||||
|
||||
def test_cache_logging(self):
|
||||
store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
|
||||
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
|
||||
girls_slice = self.get_slice("Girls", db.session)
|
||||
self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id))
|
||||
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
|
||||
assert ck.datasource_uid == f"{girls_slice.table.id}__table"
|
||||
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys
|
||||
|
||||
def test_shortner(self):
|
||||
self.login(username="admin")
|
||||
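`test_cache_logging` flips the new flag on around the warm-up request and restores the previous value afterwards, since `CacheKey` recording is no longer on by default. Deployments that still want `CacheKey` rows written to the metadata DB opt back in via `superset_config.py`:

```python
# superset_config.py
STORE_CACHE_KEYS_IN_METADATA_DB = True  # now defaults to False
```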
@@ -841,6 +845,191 @@ class TestCore(SupersetTestCase):
            "The datasource associated with this chart no longer exists",
        )

    def test_explore_json(self):
        tbl_id = self.table_ids.get("birth_names")
        form_data = {
            "queryFields": {
                "metrics": "metrics",
                "groupby": "groupby",
                "columns": "groupby",
            },
            "datasource": f"{tbl_id}__table",
            "viz_type": "dist_bar",
            "time_range_endpoints": ["inclusive", "exclusive"],
            "granularity_sqla": "ds",
            "time_range": "No filter",
            "metrics": ["count"],
            "adhoc_filters": [],
            "groupby": ["gender"],
            "row_limit": 100,
        }
        self.login(username="admin")
        rv = self.client.post(
            "/superset/explore_json/", data={"form_data": json.dumps(form_data)},
        )
        data = json.loads(rv.data.decode("utf-8"))

        self.assertEqual(rv.status_code, 200)
        self.assertEqual(data["rowcount"], 2)

    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    def test_explore_json_async(self):
        tbl_id = self.table_ids.get("birth_names")
        form_data = {
            "queryFields": {
                "metrics": "metrics",
                "groupby": "groupby",
                "columns": "groupby",
            },
            "datasource": f"{tbl_id}__table",
            "viz_type": "dist_bar",
            "time_range_endpoints": ["inclusive", "exclusive"],
            "granularity_sqla": "ds",
            "time_range": "No filter",
            "metrics": ["count"],
            "adhoc_filters": [],
            "groupby": ["gender"],
            "row_limit": 100,
        }
        async_query_manager.init_app(app)
        self.login(username="admin")
        rv = self.client.post(
            "/superset/explore_json/", data={"form_data": json.dumps(form_data)},
        )
        data = json.loads(rv.data.decode("utf-8"))
        keys = list(data.keys())

        self.assertEqual(rv.status_code, 202)
        self.assertCountEqual(
            keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
        )
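With `GLOBAL_ASYNC_QUERIES` enabled, both `/api/v1/chart/data` and `/superset/explore_json/` answer 202 with the job metadata shown above rather than the query result. A rough sketch of the client-side flow under the assumption of a `requests`-style session that already carries Superset's auth and async-JWT cookies:

```python
import requests


def submit_and_fetch(
    session: requests.Session, base_url: str, body: dict
) -> requests.Response:
    # Submit the query; 202 means a Celery worker picked it up. `body` is the
    # usual {"form_data": json.dumps(...)} POST body.
    rv = session.post(f"{base_url}/superset/explore_json/", data=body)
    assert rv.status_code == 202
    job = rv.json()  # channel_id, job_id, user_id, status, errors, result_url
    # In the browser, completion is announced over the async events channel;
    # this sketch simply fetches result_url once the job reports "done".
    return session.get(f"{base_url}{job['result_url']}")
```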
    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        GLOBAL_ASYNC_QUERIES=True,
    )
    def test_explore_json_async_results_format(self):
        tbl_id = self.table_ids.get("birth_names")
        form_data = {
            "queryFields": {
                "metrics": "metrics",
                "groupby": "groupby",
                "columns": "groupby",
            },
            "datasource": f"{tbl_id}__table",
            "viz_type": "dist_bar",
            "time_range_endpoints": ["inclusive", "exclusive"],
            "granularity_sqla": "ds",
            "time_range": "No filter",
            "metrics": ["count"],
            "adhoc_filters": [],
            "groupby": ["gender"],
            "row_limit": 100,
        }
        async_query_manager.init_app(app)
        self.login(username="admin")
        rv = self.client.post(
            "/superset/explore_json/?results=true",
            data={"form_data": json.dumps(form_data)},
        )
        self.assertEqual(rv.status_code, 200)

    @mock.patch(
        "superset.utils.cache_manager.CacheManager.cache",
        new_callable=mock.PropertyMock,
    )
    @mock.patch("superset.viz.BaseViz.force_cached", new_callable=mock.PropertyMock)
    def test_explore_json_data(self, mock_force_cached, mock_cache):
        tbl_id = self.table_ids.get("birth_names")
        form_data = dict(
            {
                "form_data": {
                    "queryFields": {
                        "metrics": "metrics",
                        "groupby": "groupby",
                        "columns": "groupby",
                    },
                    "datasource": f"{tbl_id}__table",
                    "viz_type": "dist_bar",
                    "time_range_endpoints": ["inclusive", "exclusive"],
                    "granularity_sqla": "ds",
                    "time_range": "No filter",
                    "metrics": ["count"],
                    "adhoc_filters": [],
                    "groupby": ["gender"],
                    "row_limit": 100,
                }
            }
        )

        class MockCache:
            def get(self, key):
                return form_data

            def set(self):
                return None

        mock_cache.return_value = MockCache()
        mock_force_cached.return_value = False

        self.login(username="admin")
        rv = self.client.get("/superset/explore_json/data/valid-cache-key")
        data = json.loads(rv.data.decode("utf-8"))

        self.assertEqual(rv.status_code, 200)
        self.assertEqual(data["rowcount"], 2)

    @mock.patch(
        "superset.utils.cache_manager.CacheManager.cache",
        new_callable=mock.PropertyMock,
    )
    def test_explore_json_data_no_login(self, mock_cache):
        tbl_id = self.table_ids.get("birth_names")
        form_data = dict(
            {
                "form_data": {
                    "queryFields": {
                        "metrics": "metrics",
                        "groupby": "groupby",
                        "columns": "groupby",
                    },
                    "datasource": f"{tbl_id}__table",
                    "viz_type": "dist_bar",
                    "time_range_endpoints": ["inclusive", "exclusive"],
                    "granularity_sqla": "ds",
                    "time_range": "No filter",
                    "metrics": ["count"],
                    "adhoc_filters": [],
                    "groupby": ["gender"],
                    "row_limit": 100,
                }
            }
        )

        class MockCache:
            def get(self, key):
                return form_data

            def set(self):
                return None

        mock_cache.return_value = MockCache()

        rv = self.client.get("/superset/explore_json/data/valid-cache-key")
        self.assertEqual(rv.status_code, 401)

    def test_explore_json_data_invalid_cache_key(self):
        self.login(username="admin")
        cache_key = "invalid-cache-key"
        rv = self.client.get(f"/superset/explore_json/data/{cache_key}")
        data = json.loads(rv.data.decode("utf-8"))

        self.assertEqual(rv.status_code, 404)
        self.assertEqual(data["error"], "Cached data not found")

    @mock.patch(
        "superset.security.SupersetSecurityManager.get_schemas_accessible_by_user"
    )
@@ -82,7 +82,7 @@ class TestQueryContext(SupersetTestCase):
        # construct baseline cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.cache_key(query_object)
        cache_key_original = query_context.query_cache_key(query_object)

        # make temporary change and revert it to refresh the changed_on property
        datasource = ConnectorRegistry.get_datasource(

@@ -99,7 +99,7 @@ class TestQueryContext(SupersetTestCase):
        # create new QueryContext with unchanged attributes and extract new cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_new = query_context.cache_key(query_object)
        cache_key_new = query_context.query_cache_key(query_object)

        # the new cache_key should be different due to updated datasource
        self.assertNotEqual(cache_key_original, cache_key_new)

@@ -115,20 +115,20 @@ class TestQueryContext(SupersetTestCase):
        # construct baseline cache_key from query_context with post processing operation
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.cache_key(query_object)
        cache_key_original = query_context.query_cache_key(query_object)

        # ensure added None post_processing operation doesn't change cache_key
        payload["queries"][0]["post_processing"].append(None)
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_with_null = query_context.cache_key(query_object)
        cache_key_with_null = query_context.query_cache_key(query_object)
        self.assertEqual(cache_key_original, cache_key_with_null)

        # ensure query without post processing operation is different
        payload["queries"][0].pop("post_processing")
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_without_post_processing = query_context.cache_key(query_object)
        cache_key_without_post_processing = query_context.query_cache_key(query_object)
        self.assertNotEqual(cache_key_original, cache_key_without_post_processing)

    def test_query_context_time_range_endpoints(self):
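The rename from `cache_key` to `query_cache_key` in these hunks comes from the QueryContext caching refactor: cache keys are now computed per `QueryObject` within the context. Post-rename usage, exactly as these tests exercise it:

```python
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
```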
@@ -179,13 +179,10 @@ class TestQueryContext(SupersetTestCase):
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        data = responses[0]["data"]
        data = responses["queries"][0]["data"]
        self.assertIn("name,sum__num\n", data)
        self.assertEqual(len(data.split("\n")), 12)

        ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
        assert ck.datasource_uid == f"{table.id}__table"

    def test_sql_injection_via_groupby(self):
        """
        Ensure that invalid column names in groupby are caught

@@ -197,7 +194,7 @@ class TestQueryContext(SupersetTestCase):
        payload["queries"][0]["groupby"] = ["currentDatabase()"]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_payload = query_context.get_payload()
        assert query_payload[0].get("error") is not None
        assert query_payload["queries"][0].get("error") is not None

    def test_sql_injection_via_columns(self):
        """

@@ -212,7 +209,7 @@ class TestQueryContext(SupersetTestCase):
        payload["queries"][0]["columns"] = ["*, 'extra'"]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_payload = query_context.get_payload()
        assert query_payload[0].get("error") is not None
        assert query_payload["queries"][0].get("error") is not None

    def test_sql_injection_via_metrics(self):
        """

@@ -233,7 +230,7 @@ class TestQueryContext(SupersetTestCase):
        ]
        query_context = ChartDataQueryContextSchema().load(payload)
        query_payload = query_context.get_payload()
        assert query_payload[0].get("error") is not None
        assert query_payload["queries"][0].get("error") is not None

    def test_samples_response_type(self):
        """

@@ -248,7 +245,7 @@ class TestQueryContext(SupersetTestCase):
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        data = responses[0]["data"]
        data = responses["queries"][0]["data"]
        self.assertIsInstance(data, list)
        self.assertEqual(len(data), 5)
        self.assertNotIn("sum__num", data[0])

@@ -265,7 +262,7 @@ class TestQueryContext(SupersetTestCase):
        query_context = ChartDataQueryContextSchema().load(payload)
        responses = query_context.get_payload()
        self.assertEqual(len(responses), 1)
        response = responses[0]
        response = responses["queries"][0]
        self.assertEqual(len(response), 2)
        self.assertEqual(response["language"], "sql")
        self.assertIn("SELECT", response["query"])
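All of these assertion changes follow from `get_payload()` now returning a dict keyed by `"queries"` rather than a bare list of per-query results; schematically:

```python
responses = query_context.get_payload()
# Old shape: [{"data": ..., "error": ...}, ...]
# New shape: {"queries": [{"data": ..., "error": ...}, ...]}
first_query_response = responses["queries"][0]
```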
@@ -22,7 +22,7 @@ from tests.superset_test_custom_template_processors import CustomPrestoTemplateP

AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
DEBUG = True
DEBUG = False
SUPERSET_WEBSERVER_PORT = 8081

# Allowing SQLALCHEMY_DATABASE_URI and SQLALCHEMY_EXAMPLES_URI to be defined as an env vars for

@@ -96,6 +96,8 @@ DATA_CACHE_CONFIG = {
    "CACHE_KEY_PREFIX": "superset_data_cache",
}

GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me-test-secret-change-me"


class CeleryConfig(object):
    BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
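`GLOBAL_ASYNC_QUERIES_JWT_SECRET` signs the per-session async JWT that the chart data and `explore_json` endpoints validate (the invalid-token test above expects a 401). A rough sketch with PyJWT; the user id lives in the `sub` claim per this commit, while any other claim names here are assumptions:

```python
import jwt  # PyJWT

secret = "test-secret-change-me-test-secret-change-me"

# Mint a token roughly the way the Flask app would: user id in "sub",
# plus the async event channel id (the "channel" claim name is assumed).
token = jwt.encode({"sub": 1, "channel": "some-channel-id"}, secret, algorithm="HS256")

# Decoding with the wrong secret or a garbage token raises, which is what
# drives the 401 in the invalid-token test.
claims = jwt.decode(token, secret, algorithms=["HS256"])
assert claims["sub"] == 1
```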
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for async query celery jobs in Superset"""
import re
from unittest import mock
from uuid import uuid4

import pytest

from superset import db
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager
from superset.tasks.async_queries import (
    load_chart_data_into_cache,
    load_explore_json_into_cache,
)
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
from tests.test_app import app


def get_table_by_name(name: str) -> SqlaTable:
    with app.app_context():
        return db.session.query(SqlaTable).filter_by(table_name=name).one()


class TestAsyncQueries(SupersetTestCase):
    @mock.patch.object(async_query_manager, "update_job")
    def test_load_chart_data_into_cache(self, mock_update_job):
        async_query_manager.init_app(app)
        table = get_table_by_name("birth_names")
        form_data = get_query_context(table.name, table.id, table.type)
        job_metadata = {
            "channel_id": str(uuid4()),
            "job_id": str(uuid4()),
            "user_id": 1,
            "status": "pending",
            "errors": [],
        }

        load_chart_data_into_cache(job_metadata, form_data)

        mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
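Each Celery task receives the `job_metadata` dict built above and reports its terminal state through `async_query_manager.update_job`. The call shapes asserted in this file, with the `<cache-key>` placeholder illustrative rather than literal:

```python
# Success: result_url points at the cached payload, e.g. the chart data
# cache endpoint exercised earlier in this diff.
#     update_job(job_metadata, "done", result_url="/api/v1/chart/data/<cache-key>")
# Failure (see the error tests below):
#     update_job(job_metadata, "error", errors=[{"message": "..."}])
```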
    @mock.patch.object(
        ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
    )
    @mock.patch.object(async_query_manager, "update_job")
    def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
        async_query_manager.init_app(app)
        table = get_table_by_name("birth_names")
        form_data = get_query_context(table.name, table.id, table.type)
        job_metadata = {
            "channel_id": str(uuid4()),
            "job_id": str(uuid4()),
            "user_id": 1,
            "status": "pending",
            "errors": [],
        }
        with pytest.raises(ChartDataQueryFailedError):
            load_chart_data_into_cache(job_metadata, form_data)

        mock_run_command.assert_called_with(cache=True)
        errors = [{"message": "Error: foo"}]
        mock_update_job.assert_called_with(job_metadata, "error", errors=errors)

    @mock.patch.object(async_query_manager, "update_job")
    def test_load_explore_json_into_cache(self, mock_update_job):
        async_query_manager.init_app(app)
        table = get_table_by_name("birth_names")
        form_data = {
            "queryFields": {
                "metrics": "metrics",
                "groupby": "groupby",
                "columns": "groupby",
            },
            "datasource": f"{table.id}__table",
            "viz_type": "dist_bar",
            "time_range_endpoints": ["inclusive", "exclusive"],
            "granularity_sqla": "ds",
            "time_range": "No filter",
            "metrics": ["count"],
            "adhoc_filters": [],
            "groupby": ["gender"],
            "row_limit": 100,
        }
        job_metadata = {
            "channel_id": str(uuid4()),
            "job_id": str(uuid4()),
            "user_id": 1,
            "status": "pending",
            "errors": [],
        }

        load_explore_json_into_cache(job_metadata, form_data)

        mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)

    @mock.patch.object(async_query_manager, "update_job")
    def test_load_explore_json_into_cache_error(self, mock_update_job):
        async_query_manager.init_app(app)
        form_data = {}
        job_metadata = {
            "channel_id": str(uuid4()),
            "job_id": str(uuid4()),
            "user_id": 1,
            "status": "pending",
            "errors": [],
        }

        with pytest.raises(SupersetException):
            load_explore_json_into_cache(job_metadata, form_data)

        errors = ["The datasource associated with this chart no longer exists"]
        mock_update_job.assert_called_with(job_metadata, "error", errors=errors)
@@ -163,9 +163,20 @@ class TestBaseViz(SupersetTestCase):
        datasource.database.cache_timeout = 1666
        self.assertEqual(1666, test_viz.cache_timeout)

        datasource.database.cache_timeout = None
        test_viz = viz.BaseViz(datasource, form_data={})
        self.assertEqual(
            app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"],
            test_viz.cache_timeout,
        )

        data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
        app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None
        datasource.database.cache_timeout = None
        test_viz = viz.BaseViz(datasource, form_data={})
        self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout)
        # restore DATA_CACHE_CONFIG timeout
        app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout


class TestTableViz(SupersetTestCase):
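The extended test pins down the cache-timeout fallback chain: an explicit database-level timeout wins, then `DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]`, and finally the global `CACHE_DEFAULT_TIMEOUT`. A simplified sketch of that resolution order (not the actual `BaseViz.cache_timeout` implementation, which also consults chart- and datasource-level timeouts first):

```python
from typing import Any, Dict, Optional


def resolve_cache_timeout(
    database_timeout: Optional[int], config: Dict[str, Any]
) -> int:
    # Most specific wins: the database-level timeout, if set...
    if database_timeout is not None:
        return database_timeout
    # ...then the data cache's own default...
    data_cache_timeout = config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT")
    if data_cache_timeout is not None:
        return data_cache_timeout
    # ...and finally the global cache default.
    return config["CACHE_DEFAULT_TIMEOUT"]
```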