[format] Using Black (#7769)

John Bodley 2019-06-25 13:34:48 -07:00 committed by GitHub
parent 0c9e6d0985
commit 5c58fd1802
270 changed files with 15592 additions and 14772 deletions

.pre-commit-config.yaml (new file)

@ -0,0 +1,22 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
repos:
- repo: https://github.com/ambv/black
rev: stable
hooks:
- id: black
language_version: python3.6
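The configuration above pins the upstream Black hook at its `stable` revision and a Python 3.6 interpreter. As a minimal sketch (not part of this diff, and assuming the pre-commit framework is already installed as described in the CONTRIBUTING changes further down), the hook can also be run on demand against the whole tree:

```bash
# Run only the black hook defined in .pre-commit-config.yaml against all tracked files.
pre-commit run black --all-files
```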


@ -81,7 +81,7 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
-disable=standarderror-builtin,long-builtin,dict-view-method,intern-builtin,suppressed-message,no-absolute-import,unpacking-in-except,apply-builtin,delslice-method,indexing-exception,old-raise-syntax,print-statement,cmp-builtin,reduce-builtin,useless-suppression,coerce-method,input-builtin,cmp-method,raw_input-builtin,nonzero-method,backtick,basestring-builtin,setslice-method,reload-builtin,oct-method,map-builtin-not-iterating,execfile-builtin,old-octal-literal,zip-builtin-not-iterating,buffer-builtin,getslice-method,metaclass-assignment,xrange-builtin,long-suffix,round-builtin,range-builtin-not-iterating,next-method-called,dict-iter-method,parameter-unpacking,unicode-builtin,unichr-builtin,import-star-module-level,raising-string,filter-builtin-not-iterating,old-ne-operator,using-cmp-argument,coerce-builtin,file-builtin,old-division,hex-method,invalid-unary-operand-type,missing-docstring,too-many-lines,duplicate-code
+disable=standarderror-builtin,long-builtin,dict-view-method,intern-builtin,suppressed-message,no-absolute-import,unpacking-in-except,apply-builtin,delslice-method,indexing-exception,old-raise-syntax,print-statement,cmp-builtin,reduce-builtin,useless-suppression,coerce-method,input-builtin,cmp-method,raw_input-builtin,nonzero-method,backtick,basestring-builtin,setslice-method,reload-builtin,oct-method,map-builtin-not-iterating,execfile-builtin,old-octal-literal,zip-builtin-not-iterating,buffer-builtin,getslice-method,metaclass-assignment,xrange-builtin,long-suffix,round-builtin,range-builtin-not-iterating,next-method-called,dict-iter-method,parameter-unpacking,unicode-builtin,unichr-builtin,import-star-module-level,raising-string,filter-builtin-not-iterating,old-ne-operator,using-cmp-argument,coerce-builtin,file-builtin,old-division,hex-method,invalid-unary-operand-type,missing-docstring,too-many-lines,duplicate-code,bad-continuation
[REPORTS]
@ -209,7 +209,7 @@ max-nested-blocks=5
[FORMAT]
# Maximum number of characters on a single line.
-max-line-length=90
+max-line-length=88
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$


@ -296,9 +296,9 @@ python setup.py build_sphinx
#### OS Dependencies
Make sure your machine meets the [OS dependencies](https://superset.incubator.apache.org/installation.html#os-dependencies) before following these steps.
Developers should use a virtualenv.
```
pip install virtualenv
@ -447,6 +447,15 @@ export enum FeatureFlag {
those specified under FEATURE_FLAGS in `superset_config.py`. For example, `DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False }` in `superset/config.py` and `FEATURE_FLAGS = { 'BAR': True, 'BAZ': True }` in `superset_config.py` will result
in combined feature flags of `{ 'FOO': True, 'BAR': True, 'BAZ': True }`.
## Git Hooks
Superset uses Git pre-commit hooks courtesy of [pre-commit](https://pre-commit.com/). To install run the following:
```bash
pip3 install -r requirements-dev.txt
pre-commit install
```
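Once installed, the hooks fire automatically on `git commit`. A hedged sketch of the resulting workflow (the file name and messages are illustrative, not part of this diff): if Black rewrites a staged file, the hook fails and aborts the commit, so the reformatted file is re-staged and the commit repeated.

```bash
git add superset/cli.py
git commit -m "Tweak CLI help text"   # black hook runs; if it reformats files, the commit is aborted
git add -u                            # re-stage the files Black rewrote
git commit -m "Tweak CLI help text"   # second attempt passes with formatted code
```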
## Linting
Lint the project with:
@ -461,6 +470,10 @@ npm ci
npm run lint
```
The Python code is auto-formatted using [Black](https://github.com/python/black) which
is configured as a pre-commit hook. There are also numerous [editor integrations](https://black.readthedocs.io/en/stable/editor_integration.html).
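For editors or CI jobs that invoke Black directly rather than through the hook, a minimal sketch (the target paths here are assumptions, not part of this diff; the version matches the pin added to `requirements-dev.txt`):

```bash
pip install black==19.3b0           # same version pinned in requirements-dev.txt
black --check --diff superset tests  # report and show files that would be reformatted
black superset tests                 # rewrite them in place
```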
## Testing
### Python Testing
@ -736,7 +749,7 @@ to work on `async` related features.
To do this, you'll need to:
* Add an additional database entry. We recommend you copy the connection
  string from the database labeled `main`, and then enable `SQL Lab` and the
  features you want to use. Don't forget to check the `Async` box
* Configure a results backend, here's a local `FileSystemCache` example,
  not recommended for production,


@ -14,17 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
black==19.3b0
coverage==4.5.3
flake8-commas==2.0.0
flake8-import-order==0.18.1
flake8-mypy==17.8.0
flake8-quotes==2.0.1
flake8==3.7.7
flask-cors==3.0.7
ipdb==0.12
mypy==0.670
nose==1.3.7
pip-tools==3.7.0
pre-commit==1.17.0
psycopg2-binary==2.7.5
pycodestyle==2.5.0
pyhive==0.6.1

setup.py

@ -23,113 +23,100 @@ import sys
from setuptools import find_packages, setup
if sys.version_info < (3, 6):
-    sys.exit('Sorry, Python < 3.6 is not supported')
+    sys.exit("Sorry, Python < 3.6 is not supported")
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
-PACKAGE_DIR = os.path.join(BASE_DIR, 'superset', 'static', 'assets')
-PACKAGE_FILE = os.path.join(PACKAGE_DIR, 'package.json')
+PACKAGE_DIR = os.path.join(BASE_DIR, "superset", "static", "assets")
+PACKAGE_FILE = os.path.join(PACKAGE_DIR, "package.json")
with open(PACKAGE_FILE) as package_file:
-    version_string = json.load(package_file)['version']
+    version_string = json.load(package_file)["version"]
-with io.open('README.md', encoding='utf-8') as f:
+with io.open("README.md", encoding="utf-8") as f:
    long_description = f.read()
def get_git_sha():
    try:
-        s = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+        s = subprocess.check_output(["git", "rev-parse", "HEAD"])
        return s.decode().strip()
    except Exception:
-        return ''
+        return ""
GIT_SHA = get_git_sha()
-version_info = {
-    'GIT_SHA': GIT_SHA,
-    'version': version_string,
-}
-print('-==-' * 15)
-print('VERSION: ' + version_string)
-print('GIT SHA: ' + GIT_SHA)
-print('-==-' * 15)
+version_info = {"GIT_SHA": GIT_SHA, "version": version_string}
+print("-==-" * 15)
+print("VERSION: " + version_string)
+print("GIT SHA: " + GIT_SHA)
+print("-==-" * 15)
-with open(os.path.join(PACKAGE_DIR, 'version_info.json'), 'w') as version_file:
+with open(os.path.join(PACKAGE_DIR, "version_info.json"), "w") as version_file:
    json.dump(version_info, version_file)
setup(
-    name='apache-superset',
-    description=(
-        'A modern, enterprise-ready business intelligence web application'),
+    name="apache-superset",
+    description=("A modern, enterprise-ready business intelligence web application"),
    long_description=long_description,
-    long_description_content_type='text/markdown',
+    long_description_content_type="text/markdown",
    version=version_string,
    packages=find_packages(),
    include_package_data=True,
    zip_safe=False,
-    scripts=['superset/bin/superset'],
+    scripts=["superset/bin/superset"],
    install_requires=[
-        'bleach>=3.0.2, <4.0.0',
-        'celery>=4.2.0, <5.0.0',
-        'click>=6.0, <7.0.0',  # `click`>=7 forces "-" instead of "_"
-        'colorama',
-        'contextlib2',
-        'croniter>=0.3.28',
-        'cryptography>=2.4.2',
-        'flask>=1.0.0, <2.0.0',
-        'flask-appbuilder>=2.1.5, <2.3.0',
-        'flask-caching',
-        'flask-compress',
-        'flask-talisman',
-        'flask-migrate',
-        'flask-wtf',
-        'geopy',
-        'gunicorn',  # deprecated
-        'humanize',
-        'idna',
-        'isodate',
-        'markdown>=3.0',
-        'pandas>=0.18.0, <0.24.0',  # `pandas`>=0.24.0 changes datetimelike API
-        'parsedatetime',
-        'pathlib2',
-        'polyline',
-        'pydruid>=0.5.2',
-        'python-dateutil',
-        'python-dotenv',
-        'python-geohash',
-        'pyyaml>=5.1',
-        'requests>=2.22.0',
-        'retry>=0.9.2',
-        'selenium>=3.141.0',
-        'simplejson>=3.15.0',
-        'sqlalchemy>=1.3.5,<2.0',
-        'sqlalchemy-utils>=0.33.2',
-        'sqlparse',
-        'wtforms-json',
+        "bleach>=3.0.2, <4.0.0",
+        "celery>=4.2.0, <5.0.0",
+        "click>=6.0, <7.0.0",  # `click`>=7 forces "-" instead of "_"
+        "colorama",
+        "contextlib2",
+        "croniter>=0.3.28",
+        "cryptography>=2.4.2",
+        "flask>=1.0.0, <2.0.0",
+        "flask-appbuilder>=2.1.5, <2.3.0",
+        "flask-caching",
+        "flask-compress",
+        "flask-talisman",
+        "flask-migrate",
+        "flask-wtf",
+        "geopy",
+        "gunicorn",  # deprecated
+        "humanize",
+        "idna",
+        "isodate",
+        "markdown>=3.0",
+        "pandas>=0.18.0, <0.24.0",  # `pandas`>=0.24.0 changes datetimelike API
+        "parsedatetime",
+        "pathlib2",
+        "polyline",
+        "pydruid>=0.5.2",
+        "python-dateutil",
+        "python-dotenv",
+        "python-geohash",
+        "pyyaml>=5.1",
+        "requests>=2.22.0",
+        "retry>=0.9.2",
+        "selenium>=3.141.0",
+        "simplejson>=3.15.0",
+        "sqlalchemy>=1.3.5,<2.0",
+        "sqlalchemy-utils>=0.33.2",
+        "sqlparse",
+        "wtforms-json",
    ],
    extras_require={
-        'bigquery': [
-            'pybigquery>=0.4.10',
-            'pandas_gbq>=0.10.0',
-        ],
-        'cors': ['flask-cors>=2.0.0'],
-        'gsheets': ['gsheetsdb>=0.1.9'],
-        'hive': [
-            'pyhive[hive]>=0.6.1',
-            'tableschema',
-            'thrift>=0.11.0, <1.0.0',
-        ],
-        'mysql': ['mysqlclient==1.4.2.post1'],
-        'postgres': ['psycopg2-binary==2.7.5'],
-        'presto': ['pyhive[presto]>=0.4.0'],
+        "bigquery": ["pybigquery>=0.4.10", "pandas_gbq>=0.10.0"],
+        "cors": ["flask-cors>=2.0.0"],
+        "gsheets": ["gsheetsdb>=0.1.9"],
+        "hive": ["pyhive[hive]>=0.6.1", "tableschema", "thrift>=0.11.0, <1.0.0"],
+        "mysql": ["mysqlclient==1.4.2.post1"],
+        "postgres": ["psycopg2-binary==2.7.5"],
+        "presto": ["pyhive[presto]>=0.4.0"],
    },
-    author='Apache Software Foundation',
-    author_email='dev@superset.incubator.apache.org',
-    url='https://superset.apache.org/',
+    author="Apache Software Foundation",
+    author_email="dev@superset.incubator.apache.org",
+    url="https://superset.apache.org/",
    download_url=(
-        'https://dist.apache.org/repos/dist/release/superset/' + version_string
+        "https://dist.apache.org/repos/dist/release/superset/" + version_string
    ),
-    classifiers=[
-        'Programming Language :: Python :: 3.6',
-    ],
+    classifiers=["Programming Language :: Python :: 3.6"],
)


@ -40,7 +40,7 @@ from superset.utils.core import pessimistic_connection_handling, setup_cache
wtforms_json.init() wtforms_json.init()
APP_DIR = os.path.dirname(__file__) APP_DIR = os.path.dirname(__file__)
CONFIG_MODULE = os.environ.get('SUPERSET_CONFIG', 'superset.config') CONFIG_MODULE = os.environ.get("SUPERSET_CONFIG", "superset.config")
if not os.path.exists(config.DATA_DIR): if not os.path.exists(config.DATA_DIR):
os.makedirs(config.DATA_DIR) os.makedirs(config.DATA_DIR)
@ -52,18 +52,18 @@ conf = app.config
################################################################# #################################################################
# Handling manifest file logic at app start # Handling manifest file logic at app start
################################################################# #################################################################
MANIFEST_FILE = APP_DIR + '/static/assets/dist/manifest.json' MANIFEST_FILE = APP_DIR + "/static/assets/dist/manifest.json"
manifest = {} manifest = {}
def parse_manifest_json(): def parse_manifest_json():
global manifest global manifest
try: try:
with open(MANIFEST_FILE, 'r') as f: with open(MANIFEST_FILE, "r") as f:
# the manifest inclues non-entry files # the manifest inclues non-entry files
# we only need entries in templates # we only need entries in templates
full_manifest = json.load(f) full_manifest = json.load(f)
manifest = full_manifest.get('entrypoints', {}) manifest = full_manifest.get("entrypoints", {})
except Exception: except Exception:
pass pass
@ -72,14 +72,14 @@ def get_js_manifest_files(filename):
if app.debug: if app.debug:
parse_manifest_json() parse_manifest_json()
entry_files = manifest.get(filename, {}) entry_files = manifest.get(filename, {})
return entry_files.get('js', []) return entry_files.get("js", [])
def get_css_manifest_files(filename): def get_css_manifest_files(filename):
if app.debug: if app.debug:
parse_manifest_json() parse_manifest_json()
entry_files = manifest.get(filename, {}) entry_files = manifest.get(filename, {})
return entry_files.get('css', []) return entry_files.get("css", [])
def get_unloaded_chunks(files, loaded_chunks): def get_unloaded_chunks(files, loaded_chunks):
@ -104,16 +104,16 @@ def get_manifest():
################################################################# #################################################################
for bp in conf.get('BLUEPRINTS'): for bp in conf.get("BLUEPRINTS"):
try: try:
print("Registering blueprint: '{}'".format(bp.name)) print("Registering blueprint: '{}'".format(bp.name))
app.register_blueprint(bp) app.register_blueprint(bp)
except Exception as e: except Exception as e:
print('blueprint registration failed') print("blueprint registration failed")
logging.exception(e) logging.exception(e)
if conf.get('SILENCE_FAB'): if conf.get("SILENCE_FAB"):
logging.getLogger('flask_appbuilder').setLevel(logging.ERROR) logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
if app.debug: if app.debug:
app.logger.setLevel(logging.DEBUG) # pylint: disable=no-member app.logger.setLevel(logging.DEBUG) # pylint: disable=no-member
@ -121,44 +121,46 @@ else:
# In production mode, add log handler to sys.stderr. # In production mode, add log handler to sys.stderr.
app.logger.addHandler(logging.StreamHandler()) # pylint: disable=no-member app.logger.addHandler(logging.StreamHandler()) # pylint: disable=no-member
app.logger.setLevel(logging.INFO) # pylint: disable=no-member app.logger.setLevel(logging.INFO) # pylint: disable=no-member
logging.getLogger('pyhive.presto').setLevel(logging.INFO) logging.getLogger("pyhive.presto").setLevel(logging.INFO)
db = SQLA(app) db = SQLA(app)
if conf.get('WTF_CSRF_ENABLED'): if conf.get("WTF_CSRF_ENABLED"):
csrf = CSRFProtect(app) csrf = CSRFProtect(app)
csrf_exempt_list = conf.get('WTF_CSRF_EXEMPT_LIST', []) csrf_exempt_list = conf.get("WTF_CSRF_EXEMPT_LIST", [])
for ex in csrf_exempt_list: for ex in csrf_exempt_list:
csrf.exempt(ex) csrf.exempt(ex)
pessimistic_connection_handling(db.engine) pessimistic_connection_handling(db.engine)
cache = setup_cache(app, conf.get('CACHE_CONFIG')) cache = setup_cache(app, conf.get("CACHE_CONFIG"))
tables_cache = setup_cache(app, conf.get('TABLE_NAMES_CACHE_CONFIG')) tables_cache = setup_cache(app, conf.get("TABLE_NAMES_CACHE_CONFIG"))
migrate = Migrate(app, db, directory=APP_DIR + '/migrations') migrate = Migrate(app, db, directory=APP_DIR + "/migrations")
# Logging configuration # Logging configuration
logging.basicConfig(format=app.config.get('LOG_FORMAT')) logging.basicConfig(format=app.config.get("LOG_FORMAT"))
logging.getLogger().setLevel(app.config.get('LOG_LEVEL')) logging.getLogger().setLevel(app.config.get("LOG_LEVEL"))
if app.config.get('ENABLE_TIME_ROTATE'): if app.config.get("ENABLE_TIME_ROTATE"):
logging.getLogger().setLevel(app.config.get('TIME_ROTATE_LOG_LEVEL')) logging.getLogger().setLevel(app.config.get("TIME_ROTATE_LOG_LEVEL"))
handler = TimedRotatingFileHandler( handler = TimedRotatingFileHandler(
app.config.get('FILENAME'), app.config.get("FILENAME"),
when=app.config.get('ROLLOVER'), when=app.config.get("ROLLOVER"),
interval=app.config.get('INTERVAL'), interval=app.config.get("INTERVAL"),
backupCount=app.config.get('BACKUP_COUNT')) backupCount=app.config.get("BACKUP_COUNT"),
)
logging.getLogger().addHandler(handler) logging.getLogger().addHandler(handler)
if app.config.get('ENABLE_CORS'): if app.config.get("ENABLE_CORS"):
from flask_cors import CORS from flask_cors import CORS
CORS(app, **app.config.get('CORS_OPTIONS'))
if app.config.get('ENABLE_PROXY_FIX'): CORS(app, **app.config.get("CORS_OPTIONS"))
if app.config.get("ENABLE_PROXY_FIX"):
app.wsgi_app = ProxyFix(app.wsgi_app) app.wsgi_app = ProxyFix(app.wsgi_app)
if app.config.get('ENABLE_CHUNK_ENCODING'): if app.config.get("ENABLE_CHUNK_ENCODING"):
class ChunkedEncodingFix(object): class ChunkedEncodingFix(object):
def __init__(self, app): def __init__(self, app):
@ -167,40 +169,41 @@ if app.config.get('ENABLE_CHUNK_ENCODING'):
def __call__(self, environ, start_response): def __call__(self, environ, start_response):
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end. # content-length and read the stream till the end.
if environ.get('HTTP_TRANSFER_ENCODING', '').lower() == u'chunked': if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == u"chunked":
environ['wsgi.input_terminated'] = True environ["wsgi.input_terminated"] = True
return self.app(environ, start_response) return self.app(environ, start_response)
app.wsgi_app = ChunkedEncodingFix(app.wsgi_app) app.wsgi_app = ChunkedEncodingFix(app.wsgi_app)
if app.config.get('UPLOAD_FOLDER'): if app.config.get("UPLOAD_FOLDER"):
try: try:
os.makedirs(app.config.get('UPLOAD_FOLDER')) os.makedirs(app.config.get("UPLOAD_FOLDER"))
except OSError: except OSError:
pass pass
for middleware in app.config.get('ADDITIONAL_MIDDLEWARE'): for middleware in app.config.get("ADDITIONAL_MIDDLEWARE"):
app.wsgi_app = middleware(app.wsgi_app) app.wsgi_app = middleware(app.wsgi_app)
class MyIndexView(IndexView): class MyIndexView(IndexView):
@expose('/') @expose("/")
def index(self): def index(self):
return redirect('/superset/welcome') return redirect("/superset/welcome")
custom_sm = app.config.get('CUSTOM_SECURITY_MANAGER') or SupersetSecurityManager custom_sm = app.config.get("CUSTOM_SECURITY_MANAGER") or SupersetSecurityManager
if not issubclass(custom_sm, SupersetSecurityManager): if not issubclass(custom_sm, SupersetSecurityManager):
raise Exception( raise Exception(
"""Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager, """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
not FAB's security manager. not FAB's security manager.
See [4565] in UPDATING.md""") See [4565] in UPDATING.md"""
)
with app.app_context(): with app.app_context():
appbuilder = AppBuilder( appbuilder = AppBuilder(
app, app,
db.session, db.session,
base_template='superset/base.html', base_template="superset/base.html",
indexview=MyIndexView, indexview=MyIndexView,
security_manager_class=custom_sm, security_manager_class=custom_sm,
update_perms=False, # Run `superset init` to update FAB's perms update_perms=False, # Run `superset init` to update FAB's perms
@ -208,15 +211,15 @@ with app.app_context():
security_manager = appbuilder.sm security_manager = appbuilder.sm
results_backend = app.config.get('RESULTS_BACKEND') results_backend = app.config.get("RESULTS_BACKEND")
# Merge user defined feature flags with default feature flags # Merge user defined feature flags with default feature flags
_feature_flags = app.config.get('DEFAULT_FEATURE_FLAGS') or {} _feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
_feature_flags.update(app.config.get('FEATURE_FLAGS') or {}) _feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
def get_feature_flags(): def get_feature_flags():
GET_FEATURE_FLAGS_FUNC = app.config.get('GET_FEATURE_FLAGS_FUNC') GET_FEATURE_FLAGS_FUNC = app.config.get("GET_FEATURE_FLAGS_FUNC")
if GET_FEATURE_FLAGS_FUNC: if GET_FEATURE_FLAGS_FUNC:
return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags)) return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
return _feature_flags return _feature_flags
@ -228,22 +231,22 @@ def is_feature_enabled(feature):
# Flask-Compress # Flask-Compress
if conf.get('ENABLE_FLASK_COMPRESS'): if conf.get("ENABLE_FLASK_COMPRESS"):
Compress(app) Compress(app)
if app.config['TALISMAN_ENABLED']: if app.config["TALISMAN_ENABLED"]:
talisman_config = app.config.get('TALISMAN_CONFIG') talisman_config = app.config.get("TALISMAN_CONFIG")
Talisman(app, **talisman_config) Talisman(app, **talisman_config)
# Hook that provides administrators a handle on the Flask APP # Hook that provides administrators a handle on the Flask APP
# after initialization # after initialization
flask_app_mutator = app.config.get('FLASK_APP_MUTATOR') flask_app_mutator = app.config.get("FLASK_APP_MUTATOR")
if flask_app_mutator: if flask_app_mutator:
flask_app_mutator(app) flask_app_mutator(app)
from superset import views # noqa from superset import views # noqa
# Registering sources # Registering sources
module_datasource_map = app.config.get('DEFAULT_MODULE_DS_MAP') module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get('ADDITIONAL_MODULE_DS_MAP')) module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map) ConnectorRegistry.register_sources(module_datasource_map)


@ -26,11 +26,8 @@ from colorama import Fore, Style
from pathlib2 import Path from pathlib2 import Path
import yaml import yaml
from superset import ( from superset import app, appbuilder, data, db, security_manager
app, appbuilder, data, db, security_manager, from superset.utils import core as utils, dashboard_import_export, dict_import_export
)
from superset.utils import (
core as utils, dashboard_import_export, dict_import_export)
config = app.config config = app.config
celery_app = utils.get_celery_app(config) celery_app = utils.get_celery_app(config)
@ -54,114 +51,128 @@ def init():
@app.cli.command() @app.cli.command()
@click.option('--verbose', '-v', is_flag=True, help='Show extra information') @click.option("--verbose", "-v", is_flag=True, help="Show extra information")
def version(verbose): def version(verbose):
"""Prints the current version number""" """Prints the current version number"""
print(Fore.BLUE + '-=' * 15) print(Fore.BLUE + "-=" * 15)
print(Fore.YELLOW + 'Superset ' + Fore.CYAN + '{version}'.format( print(
version=config.get('VERSION_STRING'))) Fore.YELLOW
print(Fore.BLUE + '-=' * 15) + "Superset "
+ Fore.CYAN
+ "{version}".format(version=config.get("VERSION_STRING"))
)
print(Fore.BLUE + "-=" * 15)
if verbose: if verbose:
print('[DB] : ' + '{}'.format(db.engine)) print("[DB] : " + "{}".format(db.engine))
print(Style.RESET_ALL) print(Style.RESET_ALL)
def load_examples_run(load_test_data): def load_examples_run(load_test_data):
print('Loading examples into {}'.format(db)) print("Loading examples into {}".format(db))
data.load_css_templates() data.load_css_templates()
print('Loading energy related dataset') print("Loading energy related dataset")
data.load_energy() data.load_energy()
print("Loading [World Bank's Health Nutrition and Population Stats]") print("Loading [World Bank's Health Nutrition and Population Stats]")
data.load_world_bank_health_n_pop() data.load_world_bank_health_n_pop()
print('Loading [Birth names]') print("Loading [Birth names]")
data.load_birth_names() data.load_birth_names()
print('Loading [Unicode test data]') print("Loading [Unicode test data]")
data.load_unicode_test_data() data.load_unicode_test_data()
if not load_test_data: if not load_test_data:
print('Loading [Random time series data]') print("Loading [Random time series data]")
data.load_random_time_series_data() data.load_random_time_series_data()
print('Loading [Random long/lat data]') print("Loading [Random long/lat data]")
data.load_long_lat_data() data.load_long_lat_data()
print('Loading [Country Map data]') print("Loading [Country Map data]")
data.load_country_map_data() data.load_country_map_data()
print('Loading [Multiformat time series]') print("Loading [Multiformat time series]")
data.load_multiformat_time_series() data.load_multiformat_time_series()
print('Loading [Paris GeoJson]') print("Loading [Paris GeoJson]")
data.load_paris_iris_geojson() data.load_paris_iris_geojson()
print('Loading [San Francisco population polygons]') print("Loading [San Francisco population polygons]")
data.load_sf_population_polygons() data.load_sf_population_polygons()
print('Loading [Flights data]') print("Loading [Flights data]")
data.load_flights() data.load_flights()
print('Loading [BART lines]') print("Loading [BART lines]")
data.load_bart_lines() data.load_bart_lines()
print('Loading [Multi Line]') print("Loading [Multi Line]")
data.load_multi_line() data.load_multi_line()
print('Loading [Misc Charts] dashboard') print("Loading [Misc Charts] dashboard")
data.load_misc_dashboard() data.load_misc_dashboard()
print('Loading DECK.gl demo') print("Loading DECK.gl demo")
data.load_deck_dash() data.load_deck_dash()
print('Loading [Tabbed dashboard]') print("Loading [Tabbed dashboard]")
data.load_tabbed_dashboard() data.load_tabbed_dashboard()
@app.cli.command() @app.cli.command()
@click.option('--load-test-data', '-t', is_flag=True, help='Load additional test data') @click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
def load_examples(load_test_data): def load_examples(load_test_data):
"""Loads a set of Slices and Dashboards and a supporting dataset """ """Loads a set of Slices and Dashboards and a supporting dataset """
load_examples_run(load_test_data) load_examples_run(load_test_data)
@app.cli.command() @app.cli.command()
@click.option('--datasource', '-d', help='Specify which datasource name to load, if ' @click.option(
'omitted, all datasources will be refreshed') "--datasource",
@click.option('--merge', '-m', is_flag=True, default=False, "-d",
help="Specify using 'merge' property during operation. " help="Specify which datasource name to load, if "
'Default value is False.') "omitted, all datasources will be refreshed",
)
@click.option(
"--merge",
"-m",
is_flag=True,
default=False,
help="Specify using 'merge' property during operation. " "Default value is False.",
)
def refresh_druid(datasource, merge): def refresh_druid(datasource, merge):
"""Refresh druid datasources""" """Refresh druid datasources"""
session = db.session() session = db.session()
from superset.connectors.druid.models import DruidCluster from superset.connectors.druid.models import DruidCluster
for cluster in session.query(DruidCluster).all(): for cluster in session.query(DruidCluster).all():
try: try:
cluster.refresh_datasources(datasource_name=datasource, cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
merge_flag=merge)
except Exception as e: except Exception as e:
print( print("Error while processing cluster '{}'\n{}".format(cluster, str(e)))
"Error while processing cluster '{}'\n{}".format(
cluster, str(e)))
logging.exception(e) logging.exception(e)
cluster.metadata_last_refreshed = datetime.now() cluster.metadata_last_refreshed = datetime.now()
print( print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
'Refreshed metadata from cluster '
'[' + cluster.cluster_name + ']')
session.commit() session.commit()
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'--path', '-p', "--path",
help='Path to a single JSON file or path containing multiple JSON files' "-p",
'files to import (*.json)') help="Path to a single JSON file or path containing multiple JSON files"
"files to import (*.json)",
)
@click.option( @click.option(
'--recursive', '-r', is_flag=True, default=False, "--recursive",
help='recursively search the path for json files') "-r",
is_flag=True,
default=False,
help="recursively search the path for json files",
)
def import_dashboards(path, recursive): def import_dashboards(path, recursive):
"""Import dashboards from JSON""" """Import dashboards from JSON"""
p = Path(path) p = Path(path)
@ -169,114 +180,135 @@ def import_dashboards(path, recursive):
if p.is_file(): if p.is_file():
files.append(p) files.append(p)
elif p.exists() and not recursive: elif p.exists() and not recursive:
files.extend(p.glob('*.json')) files.extend(p.glob("*.json"))
elif p.exists() and recursive: elif p.exists() and recursive:
files.extend(p.rglob('*.json')) files.extend(p.rglob("*.json"))
for f in files: for f in files:
logging.info('Importing dashboard from file %s', f) logging.info("Importing dashboard from file %s", f)
try: try:
with f.open() as data_stream: with f.open() as data_stream:
dashboard_import_export.import_dashboards( dashboard_import_export.import_dashboards(db.session, data_stream)
db.session, data_stream)
except Exception as e: except Exception as e:
logging.error('Error when importing dashboard from file %s', f) logging.error("Error when importing dashboard from file %s", f)
logging.error(e) logging.error(e)
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'--dashboard-file', '-f', default=None, "--dashboard-file", "-f", default=None, help="Specify the the file to export to"
help='Specify the the file to export to') )
@click.option( @click.option(
'--print_stdout', '-p', is_flag=True, default=False, "--print_stdout", "-p", is_flag=True, default=False, help="Print JSON to stdout"
help='Print JSON to stdout') )
def export_dashboards(print_stdout, dashboard_file): def export_dashboards(print_stdout, dashboard_file):
"""Export dashboards to JSON""" """Export dashboards to JSON"""
data = dashboard_import_export.export_dashboards(db.session) data = dashboard_import_export.export_dashboards(db.session)
if print_stdout or not dashboard_file: if print_stdout or not dashboard_file:
print(data) print(data)
if dashboard_file: if dashboard_file:
logging.info('Exporting dashboards to %s', dashboard_file) logging.info("Exporting dashboards to %s", dashboard_file)
with open(dashboard_file, 'w') as data_stream: with open(dashboard_file, "w") as data_stream:
data_stream.write(data) data_stream.write(data)
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'--path', '-p', "--path",
help='Path to a single YAML file or path containing multiple YAML ' "-p",
'files to import (*.yaml or *.yml)') help="Path to a single YAML file or path containing multiple YAML "
"files to import (*.yaml or *.yml)",
)
@click.option( @click.option(
'--sync', '-s', 'sync', default='', "--sync",
help='comma seperated list of element types to synchronize ' "-s",
'e.g. "metrics,columns" deletes metrics and columns in the DB ' "sync",
'that are not specified in the YAML file') default="",
help="comma seperated list of element types to synchronize "
'e.g. "metrics,columns" deletes metrics and columns in the DB '
"that are not specified in the YAML file",
)
@click.option( @click.option(
'--recursive', '-r', is_flag=True, default=False, "--recursive",
help='recursively search the path for yaml files') "-r",
is_flag=True,
default=False,
help="recursively search the path for yaml files",
)
def import_datasources(path, sync, recursive): def import_datasources(path, sync, recursive):
"""Import datasources from YAML""" """Import datasources from YAML"""
sync_array = sync.split(',') sync_array = sync.split(",")
p = Path(path) p = Path(path)
files = [] files = []
if p.is_file(): if p.is_file():
files.append(p) files.append(p)
elif p.exists() and not recursive: elif p.exists() and not recursive:
files.extend(p.glob('*.yaml')) files.extend(p.glob("*.yaml"))
files.extend(p.glob('*.yml')) files.extend(p.glob("*.yml"))
elif p.exists() and recursive: elif p.exists() and recursive:
files.extend(p.rglob('*.yaml')) files.extend(p.rglob("*.yaml"))
files.extend(p.rglob('*.yml')) files.extend(p.rglob("*.yml"))
for f in files: for f in files:
logging.info('Importing datasources from file %s', f) logging.info("Importing datasources from file %s", f)
try: try:
with f.open() as data_stream: with f.open() as data_stream:
dict_import_export.import_from_dict( dict_import_export.import_from_dict(
db.session, db.session, yaml.safe_load(data_stream), sync=sync_array
yaml.safe_load(data_stream), )
sync=sync_array)
except Exception as e: except Exception as e:
logging.error('Error when importing datasources from file %s', f) logging.error("Error when importing datasources from file %s", f)
logging.error(e) logging.error(e)
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'--datasource-file', '-f', default=None, "--datasource-file", "-f", default=None, help="Specify the the file to export to"
help='Specify the the file to export to') )
@click.option( @click.option(
'--print_stdout', '-p', is_flag=True, default=False, "--print_stdout", "-p", is_flag=True, default=False, help="Print YAML to stdout"
help='Print YAML to stdout') )
@click.option( @click.option(
'--back-references', '-b', is_flag=True, default=False, "--back-references",
help='Include parent back references') "-b",
is_flag=True,
default=False,
help="Include parent back references",
)
@click.option( @click.option(
'--include-defaults', '-d', is_flag=True, default=False, "--include-defaults",
help='Include fields containing defaults') "-d",
def export_datasources(print_stdout, datasource_file, is_flag=True,
back_references, include_defaults): default=False,
help="Include fields containing defaults",
)
def export_datasources(
print_stdout, datasource_file, back_references, include_defaults
):
"""Export datasources to YAML""" """Export datasources to YAML"""
data = dict_import_export.export_to_dict( data = dict_import_export.export_to_dict(
session=db.session, session=db.session,
recursive=True, recursive=True,
back_references=back_references, back_references=back_references,
include_defaults=include_defaults) include_defaults=include_defaults,
)
if print_stdout or not datasource_file: if print_stdout or not datasource_file:
yaml.safe_dump(data, stdout, default_flow_style=False) yaml.safe_dump(data, stdout, default_flow_style=False)
if datasource_file: if datasource_file:
logging.info('Exporting datasources to %s', datasource_file) logging.info("Exporting datasources to %s", datasource_file)
with open(datasource_file, 'w') as data_stream: with open(datasource_file, "w") as data_stream:
yaml.safe_dump(data, data_stream, default_flow_style=False) yaml.safe_dump(data, data_stream, default_flow_style=False)
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'--back-references', '-b', is_flag=True, default=False, "--back-references",
help='Include parent back references') "-b",
is_flag=True,
default=False,
help="Include parent back references",
)
def export_datasource_schema(back_references): def export_datasource_schema(back_references):
"""Export datasource YAML schema to stdout""" """Export datasource YAML schema to stdout"""
data = dict_import_export.export_schema_to_dict( data = dict_import_export.export_schema_to_dict(back_references=back_references)
back_references=back_references)
yaml.safe_dump(data, stdout, default_flow_style=False) yaml.safe_dump(data, stdout, default_flow_style=False)
@ -284,47 +316,49 @@ def export_datasource_schema(back_references):
def update_datasources_cache(): def update_datasources_cache():
"""Refresh sqllab datasources cache""" """Refresh sqllab datasources cache"""
from superset.models.core import Database from superset.models.core import Database
for database in db.session.query(Database).all(): for database in db.session.query(Database).all():
if database.allow_multi_schema_metadata_fetch: if database.allow_multi_schema_metadata_fetch:
print('Fetching {} datasources ...'.format(database.name)) print("Fetching {} datasources ...".format(database.name))
try: try:
database.get_all_table_names_in_database( database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60) force=True, cache=True, cache_timeout=24 * 60 * 60
)
database.get_all_view_names_in_database( database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60) force=True, cache=True, cache_timeout=24 * 60 * 60
)
except Exception as e: except Exception as e:
print('{}'.format(str(e))) print("{}".format(str(e)))
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'--workers', '-w', "--workers", "-w", type=int, help="Number of celery server workers to fire up"
type=int, )
help='Number of celery server workers to fire up')
def worker(workers): def worker(workers):
"""Starts a Superset worker for async SQL query execution.""" """Starts a Superset worker for async SQL query execution."""
logging.info( logging.info(
"The 'superset worker' command is deprecated. Please use the 'celery " "The 'superset worker' command is deprecated. Please use the 'celery "
"worker' command instead.") "worker' command instead."
)
if workers: if workers:
celery_app.conf.update(CELERYD_CONCURRENCY=workers) celery_app.conf.update(CELERYD_CONCURRENCY=workers)
elif config.get('SUPERSET_CELERY_WORKERS'): elif config.get("SUPERSET_CELERY_WORKERS"):
celery_app.conf.update( celery_app.conf.update(
CELERYD_CONCURRENCY=config.get('SUPERSET_CELERY_WORKERS')) CELERYD_CONCURRENCY=config.get("SUPERSET_CELERY_WORKERS")
)
worker = celery_app.Worker(optimization='fair') worker = celery_app.Worker(optimization="fair")
worker.start() worker.start()
@app.cli.command() @app.cli.command()
@click.option( @click.option(
'-p', '--port', "-p", "--port", default="5555", help="Port on which to start the Flower process"
default='5555', )
help='Port on which to start the Flower process')
@click.option( @click.option(
'-a', '--address', "-a", "--address", default="localhost", help="Address on which to run the service"
default='localhost', )
help='Address on which to run the service')
def flower(port, address): def flower(port, address):
"""Runs a Celery Flower web server """Runs a Celery Flower web server
@ -332,18 +366,19 @@ def flower(port, address):
broker""" broker"""
BROKER_URL = celery_app.conf.BROKER_URL BROKER_URL = celery_app.conf.BROKER_URL
cmd = ( cmd = (
'celery flower ' "celery flower "
f'--broker={BROKER_URL} ' f"--broker={BROKER_URL} "
f'--port={port} ' f"--port={port} "
f'--address={address} ' f"--address={address} "
) )
logging.info( logging.info(
"The 'superset flower' command is deprecated. Please use the 'celery " "The 'superset flower' command is deprecated. Please use the 'celery "
"flower' command instead.") "flower' command instead."
print(Fore.GREEN + 'Starting a Celery Flower instance') )
print(Fore.BLUE + '-=' * 40) print(Fore.GREEN + "Starting a Celery Flower instance")
print(Fore.BLUE + "-=" * 40)
print(Fore.YELLOW + cmd) print(Fore.YELLOW + cmd)
print(Fore.BLUE + '-=' * 40) print(Fore.BLUE + "-=" * 40)
Popen(cmd, shell=True).wait() Popen(cmd, shell=True).wait()
@ -354,7 +389,7 @@ def load_test_users():
Syncs permissions for those users/roles Syncs permissions for those users/roles
""" """
print(Fore.GREEN + 'Loading a set of users for unit tests') print(Fore.GREEN + "Loading a set of users for unit tests")
load_test_users_run() load_test_users_run()
@ -364,51 +399,73 @@ def load_test_users_run():
Syncs permissions for those users/roles Syncs permissions for those users/roles
""" """
if config.get('TESTING'): if config.get("TESTING"):
security_manager.sync_role_definitions() security_manager.sync_role_definitions()
gamma_sqllab_role = security_manager.add_role('gamma_sqllab') gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
for perm in security_manager.find_role('Gamma').permissions: for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm) security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db() utils.get_or_create_main_db()
db_perm = utils.get_main_database(security_manager.get_session).perm db_perm = utils.get_main_database(security_manager.get_session).perm
security_manager.add_permission_view_menu('database_access', db_perm) security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu( db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name='database_access') view_menu_name=db_perm, permission_name="database_access"
)
gamma_sqllab_role.permissions.append(db_pvm) gamma_sqllab_role.permissions.append(db_pvm)
for perm in security_manager.find_role('sql_lab').permissions: for perm in security_manager.find_role("sql_lab").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm) security_manager.add_permission_role(gamma_sqllab_role, perm)
admin = security_manager.find_user('admin') admin = security_manager.find_user("admin")
if not admin: if not admin:
security_manager.add_user( security_manager.add_user(
'admin', 'admin', ' user', 'admin@fab.org', "admin",
security_manager.find_role('Admin'), "admin",
password='general') " user",
"admin@fab.org",
security_manager.find_role("Admin"),
password="general",
)
gamma = security_manager.find_user('gamma') gamma = security_manager.find_user("gamma")
if not gamma: if not gamma:
security_manager.add_user( security_manager.add_user(
'gamma', 'gamma', 'user', 'gamma@fab.org', "gamma",
security_manager.find_role('Gamma'), "gamma",
password='general') "user",
"gamma@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma2 = security_manager.find_user('gamma2') gamma2 = security_manager.find_user("gamma2")
if not gamma2: if not gamma2:
security_manager.add_user( security_manager.add_user(
'gamma2', 'gamma2', 'user', 'gamma2@fab.org', "gamma2",
security_manager.find_role('Gamma'), "gamma2",
password='general') "user",
"gamma2@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma_sqllab_user = security_manager.find_user('gamma_sqllab') gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
if not gamma_sqllab_user: if not gamma_sqllab_user:
security_manager.add_user( security_manager.add_user(
'gamma_sqllab', 'gamma_sqllab', 'user', 'gamma_sqllab@fab.org', "gamma_sqllab",
gamma_sqllab_role, password='general') "gamma_sqllab",
"user",
"gamma_sqllab@fab.org",
gamma_sqllab_role,
password="general",
)
alpha = security_manager.find_user('alpha') alpha = security_manager.find_user("alpha")
if not alpha: if not alpha:
security_manager.add_user( security_manager.add_user(
'alpha', 'alpha', 'user', 'alpha@fab.org', "alpha",
security_manager.find_role('Alpha'), "alpha",
password='general') "user",
"alpha@fab.org",
security_manager.find_role("Alpha"),
password="general",
)
security_manager.get_session.commit() security_manager.get_session.commit()


@ -32,7 +32,7 @@ from superset.utils.core import DTTM_ALIAS
from .query_object import QueryObject from .query_object import QueryObject
config = app.config config = app.config
stats_logger = config.get('STATS_LOGGER') stats_logger = config.get("STATS_LOGGER")
class QueryContext: class QueryContext:
@ -41,21 +41,21 @@ class QueryContext:
to retrieve the data payload for a given viz. to retrieve the data payload for a given viz.
""" """
cache_type = 'df' cache_type = "df"
enforce_numerical_metrics = True enforce_numerical_metrics = True
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes # TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288 # a vanilla python type https://github.com/python/mypy/issues/5288
def __init__( def __init__(
self, self,
datasource: Dict, datasource: Dict,
queries: List[Dict], queries: List[Dict],
force: bool = False, force: bool = False,
custom_cache_timeout: int = None, custom_cache_timeout: int = None,
): ):
self.datasource = ConnectorRegistry.get_datasource(datasource.get('type'), self.datasource = ConnectorRegistry.get_datasource(
int(datasource.get('id')), # noqa: E501, T400 datasource.get("type"), int(datasource.get("id")), db.session # noqa: T400
db.session) )
self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries)) self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
self.force = force self.force = force
@ -72,7 +72,7 @@ class QueryContext:
# support multiple queries from different data source. # support multiple queries from different data source.
timestamp_format = None timestamp_format = None
if self.datasource.type == 'table': if self.datasource.type == "table":
dttm_col = self.datasource.get_col(query_object.granularity) dttm_col = self.datasource.get_col(query_object.granularity)
if dttm_col: if dttm_col:
timestamp_format = dttm_col.python_date_format timestamp_format = dttm_col.python_date_format
@ -88,12 +88,13 @@ class QueryContext:
# parsing logic # parsing logic
if df is not None and not df.empty: if df is not None and not df.empty:
if DTTM_ALIAS in df.columns: if DTTM_ALIAS in df.columns:
if timestamp_format in ('epoch_s', 'epoch_ms'): if timestamp_format in ("epoch_s", "epoch_ms"):
# Column has already been formatted as a timestamp. # Column has already been formatted as a timestamp.
df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(pd.Timestamp) df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(pd.Timestamp)
else: else:
df[DTTM_ALIAS] = pd.to_datetime( df[DTTM_ALIAS] = pd.to_datetime(
df[DTTM_ALIAS], utc=False, format=timestamp_format) df[DTTM_ALIAS], utc=False, format=timestamp_format
)
if self.datasource.offset: if self.datasource.offset:
df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset) df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset)
df[DTTM_ALIAS] += query_object.time_shift df[DTTM_ALIAS] += query_object.time_shift
@ -103,10 +104,10 @@ class QueryContext:
df.replace([np.inf, -np.inf], np.nan) df.replace([np.inf, -np.inf], np.nan)
return { return {
'query': result.query, "query": result.query,
'status': result.status, "status": result.status,
'error_message': result.error_message, "error_message": result.error_message,
'df': df, "df": df,
} }
def df_metrics_to_num(self, df, query_object): def df_metrics_to_num(self, df, query_object):
@ -114,23 +115,23 @@ class QueryContext:
metrics = [metric for metric in query_object.metrics] metrics = [metric for metric in query_object.metrics]
for col, dtype in df.dtypes.items(): for col, dtype in df.dtypes.items():
if dtype.type == np.object_ and col in metrics: if dtype.type == np.object_ and col in metrics:
df[col] = pd.to_numeric(df[col], errors='coerce') df[col] = pd.to_numeric(df[col], errors="coerce")
def get_data(self, df): def get_data(self, df):
return df.to_dict(orient='records') return df.to_dict(orient="records")
def get_single_payload(self, query_obj): def get_single_payload(self, query_obj):
"""Returns a payload of metadata and data""" """Returns a payload of metadata and data"""
payload = self.get_df_payload(query_obj) payload = self.get_df_payload(query_obj)
df = payload.get('df') df = payload.get("df")
status = payload.get('status') status = payload.get("status")
if status != utils.QueryStatus.FAILED: if status != utils.QueryStatus.FAILED:
if df is not None and df.empty: if df is not None and df.empty:
payload['error'] = 'No data' payload["error"] = "No data"
else: else:
payload['data'] = self.get_data(df) payload["data"] = self.get_data(df)
if 'df' in payload: if "df" in payload:
del payload['df'] del payload["df"]
return payload return payload
def get_payload(self): def get_payload(self):
@ -144,94 +145,94 @@ class QueryContext:
if self.datasource.cache_timeout is not None: if self.datasource.cache_timeout is not None:
return self.datasource.cache_timeout return self.datasource.cache_timeout
if ( if (
hasattr(self.datasource, 'database') and hasattr(self.datasource, "database")
self.datasource.database.cache_timeout) is not None: and self.datasource.database.cache_timeout
) is not None:
return self.datasource.database.cache_timeout return self.datasource.database.cache_timeout
return config.get('CACHE_DEFAULT_TIMEOUT') return config.get("CACHE_DEFAULT_TIMEOUT")
def get_df_payload(self, query_obj, **kwargs): def get_df_payload(self, query_obj, **kwargs):
"""Handles caching around the df paylod retrieval""" """Handles caching around the df paylod retrieval"""
cache_key = query_obj.cache_key( cache_key = (
datasource=self.datasource.uid, **kwargs) if query_obj else None query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
logging.info('Cache key: {}'.format(cache_key)) if query_obj
else None
)
logging.info("Cache key: {}".format(cache_key))
is_loaded = False is_loaded = False
stacktrace = None stacktrace = None
df = None df = None
cached_dttm = datetime.utcnow().isoformat().split('.')[0] cached_dttm = datetime.utcnow().isoformat().split(".")[0]
cache_value = None cache_value = None
status = None status = None
query = '' query = ""
error_message = None error_message = None
if cache_key and cache and not self.force: if cache_key and cache and not self.force:
cache_value = cache.get(cache_key) cache_value = cache.get(cache_key)
if cache_value: if cache_value:
stats_logger.incr('loaded_from_cache') stats_logger.incr("loaded_from_cache")
try: try:
cache_value = pkl.loads(cache_value) cache_value = pkl.loads(cache_value)
df = cache_value['df'] df = cache_value["df"]
query = cache_value['query'] query = cache_value["query"]
status = utils.QueryStatus.SUCCESS status = utils.QueryStatus.SUCCESS
is_loaded = True is_loaded = True
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
logging.error('Error reading cache: ' + logging.error(
utils.error_msg_from_exception(e)) "Error reading cache: " + utils.error_msg_from_exception(e)
logging.info('Serving from cache') )
logging.info("Serving from cache")
if query_obj and not is_loaded: if query_obj and not is_loaded:
try: try:
query_result = self.get_query_result(query_obj) query_result = self.get_query_result(query_obj)
status = query_result['status'] status = query_result["status"]
query = query_result['query'] query = query_result["query"]
error_message = query_result['error_message'] error_message = query_result["error_message"]
df = query_result['df'] df = query_result["df"]
if status != utils.QueryStatus.FAILED: if status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source') stats_logger.incr("loaded_from_source")
is_loaded = True is_loaded = True
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
if not error_message: if not error_message:
error_message = '{}'.format(e) error_message = "{}".format(e)
status = utils.QueryStatus.FAILED status = utils.QueryStatus.FAILED
stacktrace = traceback.format_exc() stacktrace = traceback.format_exc()
if ( if is_loaded and cache_key and cache and status != utils.QueryStatus.FAILED:
is_loaded and
cache_key and
cache and
status != utils.QueryStatus.FAILED):
try: try:
cache_value = dict( cache_value = dict(
dttm=cached_dttm, dttm=cached_dttm, df=df if df is not None else None, query=query
df=df if df is not None else None,
query=query,
) )
cache_binary = pkl.dumps( cache_binary = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
cache_value, protocol=pkl.HIGHEST_PROTOCOL)
logging.info('Caching {} chars at key {}'.format( logging.info(
len(cache_binary), cache_key)) "Caching {} chars at key {}".format(
len(cache_binary), cache_key
)
)
stats_logger.incr('set_cache_key') stats_logger.incr("set_cache_key")
cache.set( cache.set(
cache_key, cache_key, cache_value=cache_binary, timeout=self.cache_timeout
cache_value=cache_binary, )
timeout=self.cache_timeout)
except Exception as e: except Exception as e:
# cache.set call can fail if the backend is down or if # cache.set call can fail if the backend is down or if
# the key is too large or whatever other reasons # the key is too large or whatever other reasons
logging.warning('Could not cache key {}'.format(cache_key)) logging.warning("Could not cache key {}".format(cache_key))
logging.exception(e) logging.exception(e)
cache.delete(cache_key) cache.delete(cache_key)
return { return {
'cache_key': cache_key, "cache_key": cache_key,
'cached_dttm': cache_value['dttm'] if cache_value is not None else None, "cached_dttm": cache_value["dttm"] if cache_value is not None else None,
'cache_timeout': self.cache_timeout, "cache_timeout": self.cache_timeout,
'df': df, "df": df,
'error': error_message, "error": error_message,
'is_cached': cache_key is not None, "is_cached": cache_key is not None,
'query': query, "query": query,
'status': status, "status": status,
'stacktrace': stacktrace, "stacktrace": stacktrace,
'rowcount': len(df.index) if df is not None else 0, "rowcount": len(df.index) if df is not None else 0,
} }


@ -27,6 +27,7 @@ from superset.utils import core as utils
# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
# https://github.com/python/mypy/issues/5288 # https://github.com/python/mypy/issues/5288
class QueryObject: class QueryObject:
""" """
The query object's schema matches the interfaces of DB connectors like sqla The query object's schema matches the interfaces of DB connectors like sqla
@ -34,25 +35,25 @@ class QueryObject:
""" """
def __init__( def __init__(
self, self,
granularity: str, granularity: str,
metrics: List[Union[Dict, str]], metrics: List[Union[Dict, str]],
groupby: List[str] = None, groupby: List[str] = None,
filters: List[str] = None, filters: List[str] = None,
time_range: Optional[str] = None, time_range: Optional[str] = None,
time_shift: Optional[str] = None, time_shift: Optional[str] = None,
is_timeseries: bool = False, is_timeseries: bool = False,
timeseries_limit: int = 0, timeseries_limit: int = 0,
row_limit: int = app.config.get('ROW_LIMIT'), row_limit: int = app.config.get("ROW_LIMIT"),
timeseries_limit_metric: Optional[Dict] = None, timeseries_limit_metric: Optional[Dict] = None,
order_desc: bool = True, order_desc: bool = True,
extras: Optional[Dict] = None, extras: Optional[Dict] = None,
prequeries: Optional[List[Dict]] = None, prequeries: Optional[List[Dict]] = None,
is_prequery: bool = False, is_prequery: bool = False,
columns: List[str] = None, columns: List[str] = None,
orderby: List[List] = None, orderby: List[List] = None,
relative_start: str = app.config.get('DEFAULT_RELATIVE_START_TIME', 'today'), relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
relative_end: str = app.config.get('DEFAULT_RELATIVE_END_TIME', 'today'), relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"),
): ):
self.granularity = granularity self.granularity = granularity
self.from_dttm, self.to_dttm = utils.get_since_until( self.from_dttm, self.to_dttm = utils.get_since_until(
@ -69,7 +70,7 @@ class QueryObject:
# Temporal solution for backward compatability issue # Temporal solution for backward compatability issue
# due the new format of non-ad-hoc metric. # due the new format of non-ad-hoc metric.
self.metrics = [ self.metrics = [
metric if 'expressionType' in metric else metric['label'] # noqa: T484 metric if "expressionType" in metric else metric["label"] # noqa: T484
for metric in metrics for metric in metrics
] ]
self.row_limit = row_limit self.row_limit = row_limit
@ -85,22 +86,22 @@ class QueryObject:
def to_dict(self): def to_dict(self):
query_object_dict = { query_object_dict = {
'granularity': self.granularity, "granularity": self.granularity,
'from_dttm': self.from_dttm, "from_dttm": self.from_dttm,
'to_dttm': self.to_dttm, "to_dttm": self.to_dttm,
'is_timeseries': self.is_timeseries, "is_timeseries": self.is_timeseries,
'groupby': self.groupby, "groupby": self.groupby,
'metrics': self.metrics, "metrics": self.metrics,
'row_limit': self.row_limit, "row_limit": self.row_limit,
'filter': self.filter, "filter": self.filter,
'timeseries_limit': self.timeseries_limit, "timeseries_limit": self.timeseries_limit,
'timeseries_limit_metric': self.timeseries_limit_metric, "timeseries_limit_metric": self.timeseries_limit_metric,
'order_desc': self.order_desc, "order_desc": self.order_desc,
'prequeries': self.prequeries, "prequeries": self.prequeries,
'is_prequery': self.is_prequery, "is_prequery": self.is_prequery,
'extras': self.extras, "extras": self.extras,
'columns': self.columns, "columns": self.columns,
'orderby': self.orderby, "orderby": self.orderby,
} }
return query_object_dict return query_object_dict
@ -115,17 +116,14 @@ class QueryObject:
cache_dict = self.to_dict() cache_dict = self.to_dict()
cache_dict.update(extra) cache_dict.update(extra)
for k in ['from_dttm', 'to_dttm']: for k in ["from_dttm", "to_dttm"]:
del cache_dict[k] del cache_dict[k]
if self.time_range: if self.time_range:
cache_dict['time_range'] = self.time_range cache_dict["time_range"] = self.time_range
json_data = self.json_dumps(cache_dict, sort_keys=True) json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode('utf-8')).hexdigest() return hashlib.md5(json_data.encode("utf-8")).hexdigest()
def json_dumps(self, obj, sort_keys=False): def json_dumps(self, obj, sort_keys=False):
return json.dumps( return json.dumps(
obj, obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
default=utils.json_int_dttm_ser,
ignore_nan=True,
sort_keys=sort_keys,
) )
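The hunk above also shows how cache keys are derived: the query dict is serialized with sort_keys=True and hashed with MD5. A standalone sketch of the same idea, using the stdlib json module instead of Superset's serializer (so ignore_nan and the custom datetime handling are omitted):

import hashlib
import json

def cache_key_for(query_dict):
    # Sorting keys makes logically identical query dicts hash to the same value.
    payload = json.dumps(query_dict, sort_keys=True)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()

# cache_key_for({"granularity": "day", "row_limit": 50000}) -> a stable 32-char hex digest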


@ -37,18 +37,18 @@ from superset.stats_logger import DummyStatsLogger
STATS_LOGGER = DummyStatsLogger() STATS_LOGGER = DummyStatsLogger()
BASE_DIR = os.path.abspath(os.path.dirname(__file__)) BASE_DIR = os.path.abspath(os.path.dirname(__file__))
if 'SUPERSET_HOME' in os.environ: if "SUPERSET_HOME" in os.environ:
DATA_DIR = os.environ['SUPERSET_HOME'] DATA_DIR = os.environ["SUPERSET_HOME"]
else: else:
DATA_DIR = os.path.join(os.path.expanduser('~'), '.superset') DATA_DIR = os.path.join(os.path.expanduser("~"), ".superset")
# --------------------------------------------------------- # ---------------------------------------------------------
# Superset specific config # Superset specific config
# --------------------------------------------------------- # ---------------------------------------------------------
PACKAGE_DIR = os.path.join(BASE_DIR, 'static', 'assets') PACKAGE_DIR = os.path.join(BASE_DIR, "static", "assets")
PACKAGE_FILE = os.path.join(PACKAGE_DIR, 'package.json') PACKAGE_FILE = os.path.join(PACKAGE_DIR, "package.json")
with open(PACKAGE_FILE) as package_file: with open(PACKAGE_FILE) as package_file:
VERSION_STRING = json.load(package_file)['version'] VERSION_STRING = json.load(package_file)["version"]
ROW_LIMIT = 50000 ROW_LIMIT = 50000
VIZ_ROW_LIMIT = 10000 VIZ_ROW_LIMIT = 10000
@ -57,7 +57,7 @@ FILTER_SELECT_ROW_LIMIT = 10000
SUPERSET_WORKERS = 2 # deprecated SUPERSET_WORKERS = 2 # deprecated
SUPERSET_CELERY_WORKERS = 32 # deprecated SUPERSET_CELERY_WORKERS = 32 # deprecated
SUPERSET_WEBSERVER_ADDRESS = '0.0.0.0' SUPERSET_WEBSERVER_ADDRESS = "0.0.0.0"
SUPERSET_WEBSERVER_PORT = 8088 SUPERSET_WEBSERVER_PORT = 8088
# This is an important setting, and should be lower than your # This is an important setting, and should be lower than your
@ -73,10 +73,10 @@ SQLALCHEMY_TRACK_MODIFICATIONS = False
# --------------------------------------------------------- # ---------------------------------------------------------
# Your App secret key # Your App secret key
SECRET_KEY = '\2\1thisismyscretkey\1\2\e\y\y\h' # noqa SECRET_KEY = "\2\1thisismyscretkey\1\2\e\y\y\h" # noqa
# The SQLAlchemy connection string. # The SQLAlchemy connection string.
SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(DATA_DIR, 'superset.db') SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "superset.db")
# SQLALCHEMY_DATABASE_URI = 'mysql://myapp@localhost/myapp' # SQLALCHEMY_DATABASE_URI = 'mysql://myapp@localhost/myapp'
# SQLALCHEMY_DATABASE_URI = 'postgresql://root:password@localhost/myapp' # SQLALCHEMY_DATABASE_URI = 'postgresql://root:password@localhost/myapp'
@ -96,10 +96,10 @@ QUERY_SEARCH_LIMIT = 1000
WTF_CSRF_ENABLED = True WTF_CSRF_ENABLED = True
# Add endpoints that need to be exempt from CSRF protection # Add endpoints that need to be exempt from CSRF protection
WTF_CSRF_EXEMPT_LIST = ['superset.views.core.log'] WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log"]
# Whether to run the web server in debug mode or not # Whether to run the web server in debug mode or not
DEBUG = os.environ.get('FLASK_ENV') == 'development' DEBUG = os.environ.get("FLASK_ENV") == "development"
FLASK_USE_RELOAD = True FLASK_USE_RELOAD = True
# Whether to show the stacktrace on 500 error # Whether to show the stacktrace on 500 error
@ -112,10 +112,10 @@ ENABLE_PROXY_FIX = False
# GLOBALS FOR APP Builder # GLOBALS FOR APP Builder
# ------------------------------ # ------------------------------
# Uncomment to setup Your App name # Uncomment to setup Your App name
APP_NAME = 'Superset' APP_NAME = "Superset"
# Uncomment to setup an App icon # Uncomment to setup an App icon
APP_ICON = '/static/assets/images/superset-logo@2x.png' APP_ICON = "/static/assets/images/superset-logo@2x.png"
APP_ICON_WIDTH = 126 APP_ICON_WIDTH = 126
# Uncomment to specify where clicking the logo would take the user # Uncomment to specify where clicking the logo would take the user
@ -131,7 +131,7 @@ LOGO_TARGET_PATH = None
# other tz can be overridden by providing a local_config # other tz can be overridden by providing a local_config
DRUID_IS_ACTIVE = True DRUID_IS_ACTIVE = True
DRUID_TZ = tz.tzutc() DRUID_TZ = tz.tzutc()
DRUID_ANALYSIS_TYPES = ['cardinality'] DRUID_ANALYSIS_TYPES = ["cardinality"]
# ---------------------------------------------------- # ----------------------------------------------------
# AUTHENTICATION CONFIG # AUTHENTICATION CONFIG
@ -175,21 +175,21 @@ PUBLIC_ROLE_LIKE_GAMMA = False
# Babel config for translations # Babel config for translations
# --------------------------------------------------- # ---------------------------------------------------
# Setup default language # Setup default language
BABEL_DEFAULT_LOCALE = 'en' BABEL_DEFAULT_LOCALE = "en"
# Your application default translation path # Your application default translation path
BABEL_DEFAULT_FOLDER = 'superset/translations' BABEL_DEFAULT_FOLDER = "superset/translations"
# The allowed translations for your app # The allowed translations for your app
LANGUAGES = { LANGUAGES = {
'en': {'flag': 'us', 'name': 'English'}, "en": {"flag": "us", "name": "English"},
'it': {'flag': 'it', 'name': 'Italian'}, "it": {"flag": "it", "name": "Italian"},
'fr': {'flag': 'fr', 'name': 'French'}, "fr": {"flag": "fr", "name": "French"},
'zh': {'flag': 'cn', 'name': 'Chinese'}, "zh": {"flag": "cn", "name": "Chinese"},
'ja': {'flag': 'jp', 'name': 'Japanese'}, "ja": {"flag": "jp", "name": "Japanese"},
'de': {'flag': 'de', 'name': 'German'}, "de": {"flag": "de", "name": "German"},
'pt': {'flag': 'pt', 'name': 'Portuguese'}, "pt": {"flag": "pt", "name": "Portuguese"},
'pt_BR': {'flag': 'br', 'name': 'Brazilian Portuguese'}, "pt_BR": {"flag": "br", "name": "Brazilian Portuguese"},
'ru': {'flag': 'ru', 'name': 'Russian'}, "ru": {"flag": "ru", "name": "Russian"},
'ko': {'flag': 'kr', 'name': 'Korean'}, "ko": {"flag": "kr", "name": "Korean"},
} }
# --------------------------------------------------- # ---------------------------------------------------
@ -202,7 +202,7 @@ LANGUAGES = {
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True } # will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
DEFAULT_FEATURE_FLAGS = { DEFAULT_FEATURE_FLAGS = {
# Experimental feature introducing a client (browser) cache # Experimental feature introducing a client (browser) cache
'CLIENT_CACHE': False, "CLIENT_CACHE": False
} }
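The merge described in the comment above (defaults overridden by flags from superset_config) is a plain dict update. A small sketch using the FOO/BAR/BAZ names from the comment; the FEATURE_FLAGS override name is assumed here, it is not shown in this hunk:

DEFAULT_FEATURE_FLAGS = {"FOO": True, "BAR": False}
FEATURE_FLAGS = {"BAR": True, "BAZ": True}  # e.g. set in superset_config.py

combined = dict(DEFAULT_FEATURE_FLAGS)
combined.update(FEATURE_FLAGS)
assert combined == {"FOO": True, "BAR": True, "BAZ": True}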
# A function that receives a dict of all feature flags # A function that receives a dict of all feature flags
@ -225,19 +225,19 @@ GET_FEATURE_FLAGS_FUNC = None
# Image and file configuration # Image and file configuration
# --------------------------------------------------- # ---------------------------------------------------
# The file upload folder, when using models with files # The file upload folder, when using models with files
UPLOAD_FOLDER = BASE_DIR + '/app/static/uploads/' UPLOAD_FOLDER = BASE_DIR + "/app/static/uploads/"
# The image upload folder, when using models with images # The image upload folder, when using models with images
IMG_UPLOAD_FOLDER = BASE_DIR + '/app/static/uploads/' IMG_UPLOAD_FOLDER = BASE_DIR + "/app/static/uploads/"
# The image upload url, when using models with images # The image upload url, when using models with images
IMG_UPLOAD_URL = '/static/uploads/' IMG_UPLOAD_URL = "/static/uploads/"
# Setup image size default is (300, 200, True) # Setup image size default is (300, 200, True)
# IMG_SIZE = (300, 200, True) # IMG_SIZE = (300, 200, True)
CACHE_DEFAULT_TIMEOUT = 60 * 60 * 24 CACHE_DEFAULT_TIMEOUT = 60 * 60 * 24
CACHE_CONFIG = {'CACHE_TYPE': 'null'} CACHE_CONFIG = {"CACHE_TYPE": "null"}
TABLE_NAMES_CACHE_CONFIG = {'CACHE_TYPE': 'null'} TABLE_NAMES_CACHE_CONFIG = {"CACHE_TYPE": "null"}
# CORS Options # CORS Options
ENABLE_CORS = False ENABLE_CORS = False
@ -252,13 +252,12 @@ SUPERSET_WEBSERVER_DOMAINS = None
# Allowed format types for upload on Database view # Allowed format types for upload on Database view
# TODO: Add processing of other spreadsheet formats (xls, xlsx etc) # TODO: Add processing of other spreadsheet formats (xls, xlsx etc)
ALLOWED_EXTENSIONS = set(['csv']) ALLOWED_EXTENSIONS = set(["csv"])
# CSV Options: key/value pairs that will be passed as argument to DataFrame.to_csv method # CSV Options: key/value pairs that will be passed as argument to DataFrame.to_csv
# method.
# note: index option should not be overridden # note: index option should not be overridden
CSV_EXPORT = { CSV_EXPORT = {"encoding": "utf-8"}
'encoding': 'utf-8',
}
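Per the comment above, CSV_EXPORT is unpacked as keyword arguments to DataFrame.to_csv. A rough sketch of that pass-through, with pandas assumed and a made-up file name:

import pandas as pd

CSV_EXPORT = {"encoding": "utf-8"}

df = pd.DataFrame({"a": [1, 2]})
# "index" is controlled by the caller, so it must not appear in CSV_EXPORT itself.
df.to_csv("export_example.csv", index=False, **CSV_EXPORT)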
# --------------------------------------------------- # ---------------------------------------------------
# Time grain configurations # Time grain configurations
@ -301,10 +300,12 @@ DRUID_DATA_SOURCE_BLACKLIST = []
# -------------------------------------------------- # --------------------------------------------------
# Modules, datasources and middleware to be registered # Modules, datasources and middleware to be registered
# -------------------------------------------------- # --------------------------------------------------
DEFAULT_MODULE_DS_MAP = OrderedDict([ DEFAULT_MODULE_DS_MAP = OrderedDict(
('superset.connectors.sqla.models', ['SqlaTable']), [
('superset.connectors.druid.models', ['DruidDatasource']), ("superset.connectors.sqla.models", ["SqlaTable"]),
]) ("superset.connectors.druid.models", ["DruidDatasource"]),
]
)
ADDITIONAL_MODULE_DS_MAP = {} ADDITIONAL_MODULE_DS_MAP = {}
ADDITIONAL_MIDDLEWARE = [] ADDITIONAL_MIDDLEWARE = []
@ -315,8 +316,8 @@ ADDITIONAL_MIDDLEWARE = []
# Console Log Settings # Console Log Settings
LOG_FORMAT = '%(asctime)s:%(levelname)s:%(name)s:%(message)s' LOG_FORMAT = "%(asctime)s:%(levelname)s:%(name)s:%(message)s"
LOG_LEVEL = 'DEBUG' LOG_LEVEL = "DEBUG"
# --------------------------------------------------- # ---------------------------------------------------
# Enable Time Rotate Log Handler # Enable Time Rotate Log Handler
@ -324,9 +325,9 @@ LOG_LEVEL = 'DEBUG'
# LOG_LEVEL = DEBUG, INFO, WARNING, ERROR, CRITICAL # LOG_LEVEL = DEBUG, INFO, WARNING, ERROR, CRITICAL
ENABLE_TIME_ROTATE = False ENABLE_TIME_ROTATE = False
TIME_ROTATE_LOG_LEVEL = 'DEBUG' TIME_ROTATE_LOG_LEVEL = "DEBUG"
FILENAME = os.path.join(DATA_DIR, 'superset.log') FILENAME = os.path.join(DATA_DIR, "superset.log")
ROLLOVER = 'midnight' ROLLOVER = "midnight"
INTERVAL = 1 INTERVAL = 1
BACKUP_COUNT = 30 BACKUP_COUNT = 30
@ -344,7 +345,7 @@ BACKUP_COUNT = 30
# pass # pass
# Set this API key to enable Mapbox visualizations # Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = os.environ.get('MAPBOX_API_KEY', '') MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")
# Maximum number of rows returned from a database # Maximum number of rows returned from a database
# in async mode, no more than SQL_MAX_ROW will be returned and stored # in async mode, no more than SQL_MAX_ROW will be returned and stored
@ -378,31 +379,26 @@ WARNING_MSG = None
class CeleryConfig(object): class CeleryConfig(object):
BROKER_URL = 'sqla+sqlite:///celerydb.sqlite' BROKER_URL = "sqla+sqlite:///celerydb.sqlite"
CELERY_IMPORTS = ( CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks")
'superset.sql_lab', CELERY_RESULT_BACKEND = "db+sqlite:///celery_results.sqlite"
'superset.tasks', CELERYD_LOG_LEVEL = "DEBUG"
)
CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite'
CELERYD_LOG_LEVEL = 'DEBUG'
CELERYD_PREFETCH_MULTIPLIER = 1 CELERYD_PREFETCH_MULTIPLIER = 1
CELERY_ACKS_LATE = True CELERY_ACKS_LATE = True
CELERY_ANNOTATIONS = { CELERY_ANNOTATIONS = {
'sql_lab.get_sql_results': { "sql_lab.get_sql_results": {"rate_limit": "100/s"},
'rate_limit': '100/s', "email_reports.send": {
}, "rate_limit": "1/s",
'email_reports.send': { "time_limit": 120,
'rate_limit': '1/s', "soft_time_limit": 150,
'time_limit': 120, "ignore_result": True,
'soft_time_limit': 150,
'ignore_result': True,
}, },
} }
CELERYBEAT_SCHEDULE = { CELERYBEAT_SCHEDULE = {
'email_reports.schedule_hourly': { "email_reports.schedule_hourly": {
'task': 'email_reports.schedule_hourly', "task": "email_reports.schedule_hourly",
'schedule': crontab(minute=1, hour='*'), "schedule": crontab(minute=1, hour="*"),
}, }
} }
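CeleryConfig above follows Celery's config-object convention; a hedged sketch of how such an object can be consumed (the celery_app name and the trimmed class are illustrative, not taken from this diff):

from celery import Celery

class CeleryConfig:  # trimmed stand-in for the class shown above
    BROKER_URL = "sqla+sqlite:///celerydb.sqlite"
    CELERYD_PREFETCH_MULTIPLIER = 1
    CELERY_ACKS_LATE = True

celery_app = Celery(__name__)
# Celery picks up the uppercase attributes as settings.
celery_app.config_from_object(CeleryConfig)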
@ -444,7 +440,7 @@ CSV_TO_HIVE_UPLOAD_S3_BUCKET = None
# The directory within the bucket specified above that will # The directory within the bucket specified above that will
# contain all the external tables # contain all the external tables
CSV_TO_HIVE_UPLOAD_DIRECTORY = 'EXTERNAL_HIVE_TABLES/' CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
# The namespace within hive where the tables created from # The namespace within hive where the tables created from
# uploading CSVs will be stored. # uploading CSVs will be stored.
@ -458,9 +454,9 @@ JINJA_CONTEXT_ADDONS = {}
# Roles that are controlled by the API / Superset and should not be changed # Roles that are controlled by the API / Superset and should not be changed
# by humans. # by humans.
ROBOT_PERMISSION_ROLES = ['Public', 'Gamma', 'Alpha', 'Admin', 'sql_lab'] ROBOT_PERMISSION_ROLES = ["Public", "Gamma", "Alpha", "Admin", "sql_lab"]
CONFIG_PATH_ENV_VAR = 'SUPERSET_CONFIG_PATH' CONFIG_PATH_ENV_VAR = "SUPERSET_CONFIG_PATH"
# If a callable is specified, it will be called at app startup while passing # If a callable is specified, it will be called at app startup while passing
# a reference to the Flask app. This can be used to alter the Flask app # a reference to the Flask app. This can be used to alter the Flask app
@ -474,16 +470,16 @@ ENABLE_ACCESS_REQUEST = False
# smtp server configuration # smtp server configuration
EMAIL_NOTIFICATIONS = False # all the emails are sent using dryrun EMAIL_NOTIFICATIONS = False # all the emails are sent using dryrun
SMTP_HOST = 'localhost' SMTP_HOST = "localhost"
SMTP_STARTTLS = True SMTP_STARTTLS = True
SMTP_SSL = False SMTP_SSL = False
SMTP_USER = 'superset' SMTP_USER = "superset"
SMTP_PORT = 25 SMTP_PORT = 25
SMTP_PASSWORD = 'superset' SMTP_PASSWORD = "superset"
SMTP_MAIL_FROM = 'superset@superset.com' SMTP_MAIL_FROM = "superset@superset.com"
if not CACHE_DEFAULT_TIMEOUT: if not CACHE_DEFAULT_TIMEOUT:
CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT') CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get("CACHE_DEFAULT_TIMEOUT")
# Whether to bump the logging level to ERROR on the flask_appbuilder package # Whether to bump the logging level to ERROR on the flask_appbuilder package
# Set to False if/when debugging FAB related issues like # Set to False if/when debugging FAB related issues like
@ -492,14 +488,14 @@ SILENCE_FAB = True
# The link to a page containing common errors and their resolutions # The link to a page containing common errors and their resolutions
# It will be appended at the bottom of sql_lab errors. # It will be appended at the bottom of sql_lab errors.
TROUBLESHOOTING_LINK = '' TROUBLESHOOTING_LINK = ""
# CSRF token timeout, set to None for a token that never expires # CSRF token timeout, set to None for a token that never expires
WTF_CSRF_TIME_LIMIT = 60 * 60 * 24 * 7 WTF_CSRF_TIME_LIMIT = 60 * 60 * 24 * 7
# This link should lead to a page with instructions on how to gain access to a # This link should lead to a page with instructions on how to gain access to a
# Datasource. It will be placed at the bottom of permissions errors. # Datasource. It will be placed at the bottom of permissions errors.
PERMISSION_INSTRUCTIONS_LINK = '' PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your # Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app # configuration. These blueprints will get integrated in the app
@ -565,7 +561,7 @@ EMAIL_REPORTS_CRON_RESOLUTION = 15
# Email report configuration # Email report configuration
# From address in emails # From address in emails
EMAIL_REPORT_FROM_ADDRESS = 'reports@superset.org' EMAIL_REPORT_FROM_ADDRESS = "reports@superset.org"
# Send bcc of all reports to this address. Set to None to disable. # Send bcc of all reports to this address. Set to None to disable.
# This is useful for maintaining an audit trail of all email deliveries. # This is useful for maintaining an audit trail of all email deliveries.
@ -575,8 +571,8 @@ EMAIL_REPORT_BCC_ADDRESS = None
# This user should have permissions to browse all the dashboards and # This user should have permissions to browse all the dashboards and
# slices. # slices.
# TODO: In the future, login as the owner of the item to generate reports # TODO: In the future, login as the owner of the item to generate reports
EMAIL_REPORTS_USER = 'admin' EMAIL_REPORTS_USER = "admin"
EMAIL_REPORTS_SUBJECT_PREFIX = '[Report] ' EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
# The webdriver to use for generating reports. Use one of the following # The webdriver to use for generating reports. Use one of the following
# firefox # firefox
@ -585,19 +581,16 @@ EMAIL_REPORTS_SUBJECT_PREFIX = '[Report] '
# chrome: # chrome:
# Requires: headless chrome # Requires: headless chrome
# Limitations: unable to generate screenshots of elements # Limitations: unable to generate screenshots of elements
EMAIL_REPORTS_WEBDRIVER = 'firefox' EMAIL_REPORTS_WEBDRIVER = "firefox"
# Window size - this will impact the rendering of the data # Window size - this will impact the rendering of the data
WEBDRIVER_WINDOW = { WEBDRIVER_WINDOW = {"dashboard": (1600, 2000), "slice": (3000, 1200)}
'dashboard': (1600, 2000),
'slice': (3000, 1200),
}
# Any config options to be passed as-is to the webdriver # Any config options to be passed as-is to the webdriver
WEBDRIVER_CONFIGURATION = {} WEBDRIVER_CONFIGURATION = {}
# The base URL to query for accessing the user interface # The base URL to query for accessing the user interface
WEBDRIVER_BASEURL = 'http://0.0.0.0:8080/' WEBDRIVER_BASEURL = "http://0.0.0.0:8080/"
# Send user to a link where they can report bugs # Send user to a link where they can report bugs
BUG_REPORT_URL = None BUG_REPORT_URL = None
@ -611,33 +604,34 @@ DOCUMENTATION_URL = None
# filter a moving window. By only setting the end time to now, # filter a moving window. By only setting the end time to now,
# start time will be set to midnight, while end will be relative to # start time will be set to midnight, while end will be relative to
# the query issue time. # the query issue time.
DEFAULT_RELATIVE_START_TIME = 'today' DEFAULT_RELATIVE_START_TIME = "today"
DEFAULT_RELATIVE_END_TIME = 'today' DEFAULT_RELATIVE_END_TIME = "today"
# Configure which SQL validator to use for each engine # Configure which SQL validator to use for each engine
SQL_VALIDATORS_BY_ENGINE = { SQL_VALIDATORS_BY_ENGINE = {"presto": "PrestoDBSQLValidator"}
'presto': 'PrestoDBSQLValidator',
}
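SQL_VALIDATORS_BY_ENGINE above maps an engine name to a validator class name; a tiny sketch of the lookup (the helper is hypothetical):

SQL_VALIDATORS_BY_ENGINE = {"presto": "PrestoDBSQLValidator"}

def validator_name_for(engine):
    # Returns None when no validator is configured for the engine.
    return SQL_VALIDATORS_BY_ENGINE.get(engine)

# validator_name_for("presto") -> "PrestoDBSQLValidator"; validator_name_for("mysql") -> None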
# Do you want Talisman enabled? # Do you want Talisman enabled?
TALISMAN_ENABLED = False TALISMAN_ENABLED = False
# If you want Talisman, how do you want it configured?? # If you want Talisman, how do you want it configured??
TALISMAN_CONFIG = { TALISMAN_CONFIG = {
'content_security_policy': None, "content_security_policy": None,
'force_https': True, "force_https": True,
'force_https_permanent': False, "force_https_permanent": False,
} }
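TALISMAN_CONFIG above reads like keyword arguments for flask-talisman; a hedged sketch of applying it to a throwaway Flask app (whether Superset wires it exactly this way is not shown in this hunk):

from flask import Flask
from flask_talisman import Talisman

TALISMAN_ENABLED = True
TALISMAN_CONFIG = {
    "content_security_policy": None,
    "force_https": True,
    "force_https_permanent": False,
}

app = Flask(__name__)
if TALISMAN_ENABLED:
    Talisman(app, **TALISMAN_CONFIG)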
try: try:
if CONFIG_PATH_ENV_VAR in os.environ: if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not in pythonpath; useful # Explicitly import config module that is not in pythonpath; useful
# for case where app is being executed via pex. # for case where app is being executed via pex.
print('Loaded your LOCAL configuration at [{}]'.format( print(
os.environ[CONFIG_PATH_ENV_VAR])) "Loaded your LOCAL configuration at [{}]".format(
os.environ[CONFIG_PATH_ENV_VAR]
)
)
module = sys.modules[__name__] module = sys.modules[__name__]
override_conf = imp.load_source( override_conf = imp.load_source(
'superset_config', "superset_config", os.environ[CONFIG_PATH_ENV_VAR]
os.environ[CONFIG_PATH_ENV_VAR]) )
for key in dir(override_conf): for key in dir(override_conf):
if key.isupper(): if key.isupper():
setattr(module, key, getattr(override_conf, key)) setattr(module, key, getattr(override_conf, key))
@ -645,7 +639,9 @@ try:
else: else:
from superset_config import * # noqa from superset_config import * # noqa
import superset_config import superset_config
print('Loaded your LOCAL configuration at [{}]'.format(
superset_config.__file__)) print(
"Loaded your LOCAL configuration at [{}]".format(superset_config.__file__)
)
except ImportError: except ImportError:
pass pass
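The try block above copies every uppercase attribute of the loaded superset_config module onto this module. A self-contained sketch of that copy step, using a SimpleNamespace as a stand-in for the loaded module:

import sys
import types

override_conf = types.SimpleNamespace(ROW_LIMIT=1000, helper="ignored")

module = sys.modules[__name__]
for key in dir(override_conf):
    if key.isupper():
        # Only uppercase names are treated as configuration overrides.
        setattr(module, key, getattr(override_conf, key))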


@ -17,9 +17,7 @@
# pylint: disable=C,R,W # pylint: disable=C,R,W
import json import json
from sqlalchemy import ( from sqlalchemy import and_, Boolean, Column, Integer, String, Text
and_, Boolean, Column, Integer, String, Text,
)
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import foreign, relationship from sqlalchemy.orm import foreign, relationship
@ -67,7 +65,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@declared_attr @declared_attr
def slices(self): def slices(self):
return relationship( return relationship(
'Slice', "Slice",
primaryjoin=lambda: and_( primaryjoin=lambda: and_(
foreign(Slice.datasource_id) == self.id, foreign(Slice.datasource_id) == self.id,
foreign(Slice.datasource_type) == self.type, foreign(Slice.datasource_type) == self.type,
@ -82,11 +80,11 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@property @property
def uid(self): def uid(self):
"""Unique id across datasource types""" """Unique id across datasource types"""
return f'{self.id}__{self.type}' return f"{self.id}__{self.type}"
@property @property
def column_names(self): def column_names(self):
return sorted([c.column_name for c in self.columns], key=lambda x: x or '') return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property @property
def columns_types(self): def columns_types(self):
@ -94,7 +92,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@property @property
def main_dttm_col(self): def main_dttm_col(self):
return 'timestamp' return "timestamp"
@property @property
def datasource_name(self): def datasource_name(self):
@ -120,22 +118,18 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@property @property
def url(self): def url(self):
return '/{}/edit/{}'.format(self.baselink, self.id) return "/{}/edit/{}".format(self.baselink, self.id)
@property @property
def explore_url(self): def explore_url(self):
if self.default_endpoint: if self.default_endpoint:
return self.default_endpoint return self.default_endpoint
else: else:
return '/superset/explore/{obj.type}/{obj.id}/'.format(obj=self) return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
@property @property
def column_formats(self): def column_formats(self):
return { return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
m.metric_name: m.d3format
for m in self.metrics
if m.d3format
}
def add_missing_metrics(self, metrics): def add_missing_metrics(self, metrics):
exisiting_metrics = {m.metric_name for m in self.metrics} exisiting_metrics = {m.metric_name for m in self.metrics}
@ -148,14 +142,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
def short_data(self): def short_data(self):
"""Data representation of the datasource sent to the frontend""" """Data representation of the datasource sent to the frontend"""
return { return {
'edit_url': self.url, "edit_url": self.url,
'id': self.id, "id": self.id,
'uid': self.uid, "uid": self.uid,
'schema': self.schema, "schema": self.schema,
'name': self.name, "name": self.name,
'type': self.type, "type": self.type,
'connection': self.connection, "connection": self.connection,
'creator': str(self.created_by), "creator": str(self.created_by),
} }
@property @property
@ -168,68 +162,65 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
order_by_choices = [] order_by_choices = []
# self.column_names returns sorted column_names # self.column_names returns sorted column_names
for s in self.column_names: for s in self.column_names:
s = str(s or '') s = str(s or "")
order_by_choices.append((json.dumps([s, True]), s + ' [asc]')) order_by_choices.append((json.dumps([s, True]), s + " [asc]"))
order_by_choices.append((json.dumps([s, False]), s + ' [desc]')) order_by_choices.append((json.dumps([s, False]), s + " [desc]"))
verbose_map = {'__timestamp': 'Time'} verbose_map = {"__timestamp": "Time"}
verbose_map.update({ verbose_map.update(
o.metric_name: o.verbose_name or o.metric_name {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
for o in self.metrics )
}) verbose_map.update(
verbose_map.update({ {o.column_name: o.verbose_name or o.column_name for o in self.columns}
o.column_name: o.verbose_name or o.column_name )
for o in self.columns
})
return { return {
# simple fields # simple fields
'id': self.id, "id": self.id,
'column_formats': self.column_formats, "column_formats": self.column_formats,
'description': self.description, "description": self.description,
'database': self.database.data, # pylint: disable=no-member "database": self.database.data, # pylint: disable=no-member
'default_endpoint': self.default_endpoint, "default_endpoint": self.default_endpoint,
'filter_select': self.filter_select_enabled, # TODO deprecate "filter_select": self.filter_select_enabled, # TODO deprecate
'filter_select_enabled': self.filter_select_enabled, "filter_select_enabled": self.filter_select_enabled,
'name': self.name, "name": self.name,
'datasource_name': self.datasource_name, "datasource_name": self.datasource_name,
'type': self.type, "type": self.type,
'schema': self.schema, "schema": self.schema,
'offset': self.offset, "offset": self.offset,
'cache_timeout': self.cache_timeout, "cache_timeout": self.cache_timeout,
'params': self.params, "params": self.params,
'perm': self.perm, "perm": self.perm,
'edit_url': self.url, "edit_url": self.url,
# sqla-specific # sqla-specific
'sql': self.sql, "sql": self.sql,
# one to many # one to many
'columns': [o.data for o in self.columns], "columns": [o.data for o in self.columns],
'metrics': [o.data for o in self.metrics], "metrics": [o.data for o in self.metrics],
# TODO deprecate, move logic to JS # TODO deprecate, move logic to JS
'order_by_choices': order_by_choices, "order_by_choices": order_by_choices,
'owners': [owner.id for owner in self.owners], "owners": [owner.id for owner in self.owners],
'verbose_map': verbose_map, "verbose_map": verbose_map,
'select_star': self.select_star, "select_star": self.select_star,
} }
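One detail of the hunk above worth unpacking is verbose_map: it maps metric and column names to their verbose labels, falling back to the raw name. A small sketch with stand-in objects:

from types import SimpleNamespace

metrics = [SimpleNamespace(metric_name="count", verbose_name="COUNT(*)")]
columns = [SimpleNamespace(column_name="ds", verbose_name=None)]

verbose_map = {"__timestamp": "Time"}
verbose_map.update({m.metric_name: m.verbose_name or m.metric_name for m in metrics})
verbose_map.update({c.column_name: c.verbose_name or c.column_name for c in columns})
# verbose_map == {"__timestamp": "Time", "count": "COUNT(*)", "ds": "ds"}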
@staticmethod @staticmethod
def filter_values_handler( def filter_values_handler(
values, target_column_is_numeric=False, is_list_target=False): values, target_column_is_numeric=False, is_list_target=False
):
def handle_single_value(v): def handle_single_value(v):
# backward compatibility with previous <select> components # backward compatibility with previous <select> components
if isinstance(v, str): if isinstance(v, str):
v = v.strip('\t\n\'"') v = v.strip("\t\n'\"")
if target_column_is_numeric: if target_column_is_numeric:
# For backwards compatibility and edge cases # For backwards compatibility and edge cases
# where a column data type might have changed # where a column data type might have changed
v = utils.string_to_num(v) v = utils.string_to_num(v)
if v == '<NULL>': if v == "<NULL>":
return None return None
elif v == '<empty string>': elif v == "<empty string>":
return '' return ""
return v return v
if isinstance(values, (list, tuple)): if isinstance(values, (list, tuple)):
values = [handle_single_value(v) for v in values] values = [handle_single_value(v) for v in values]
else: else:
@ -278,8 +269,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
if col.column_name == column_name: if col.column_name == column_name:
return col return col
def get_fk_many_from_list( def get_fk_many_from_list(self, object_list, fkmany, fkmany_class, key_attr):
self, object_list, fkmany, fkmany_class, key_attr):
"""Update ORM one-to-many list from object list """Update ORM one-to-many list from object list
Used for syncing metrics and columns using the same code""" Used for syncing metrics and columns using the same code"""
@ -302,13 +292,10 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
for obj in object_list: for obj in object_list:
key = obj.get(key_attr) key = obj.get(key_attr)
if key not in orm_keys: if key not in orm_keys:
del obj['id'] del obj["id"]
orm_kwargs = {} orm_kwargs = {}
for k in obj: for k in obj:
if ( if k in fkmany_class.update_from_object_fields and k in obj:
k in fkmany_class.update_from_object_fields and
k in obj
):
orm_kwargs[k] = obj[k] orm_kwargs[k] = obj[k]
new_obj = fkmany_class(**orm_kwargs) new_obj = fkmany_class(**orm_kwargs)
new_fks.append(new_obj) new_fks.append(new_obj)
@ -329,16 +316,18 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
for attr in self.update_from_object_fields: for attr in self.update_from_object_fields:
setattr(self, attr, obj.get(attr)) setattr(self, attr, obj.get(attr))
self.owners = obj.get('owners', []) self.owners = obj.get("owners", [])
# Syncing metrics # Syncing metrics
metrics = self.get_fk_many_from_list( metrics = self.get_fk_many_from_list(
obj.get('metrics'), self.metrics, self.metric_class, 'metric_name') obj.get("metrics"), self.metrics, self.metric_class, "metric_name"
)
self.metrics = metrics self.metrics = metrics
# Syncing columns # Syncing columns
self.columns = self.get_fk_many_from_list( self.columns = self.get_fk_many_from_list(
obj.get('columns'), self.columns, self.column_class, 'column_name') obj.get("columns"), self.columns, self.column_class, "column_name"
)
class BaseColumn(AuditMixinNullable, ImportMixin): class BaseColumn(AuditMixinNullable, ImportMixin):
@ -363,32 +352,31 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
return self.column_name return self.column_name
num_types = ( num_types = (
'DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'NUMBER', "DOUBLE",
'LONG', 'REAL', 'NUMERIC', 'DECIMAL', 'MONEY', "FLOAT",
"INT",
"BIGINT",
"NUMBER",
"LONG",
"REAL",
"NUMERIC",
"DECIMAL",
"MONEY",
) )
date_types = ('DATE', 'TIME', 'DATETIME') date_types = ("DATE", "TIME", "DATETIME")
str_types = ('VARCHAR', 'STRING', 'CHAR') str_types = ("VARCHAR", "STRING", "CHAR")
@property @property
def is_num(self): def is_num(self):
return ( return self.type and any([t in self.type.upper() for t in self.num_types])
self.type and
any([t in self.type.upper() for t in self.num_types])
)
@property @property
def is_time(self): def is_time(self):
return ( return self.type and any([t in self.type.upper() for t in self.date_types])
self.type and
any([t in self.type.upper() for t in self.date_types])
)
@property @property
def is_string(self): def is_string(self):
return ( return self.type and any([t in self.type.upper() for t in self.str_types])
self.type and
any([t in self.type.upper() for t in self.str_types])
)
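The three properties above share one pattern: substring-match the upper-cased column type against a tuple of markers. A standalone sketch of that check (the example type strings are made up):

NUM_TYPES = ("DOUBLE", "FLOAT", "INT", "BIGINT", "NUMBER",
             "LONG", "REAL", "NUMERIC", "DECIMAL", "MONEY")

def is_num(column_type):
    # "BIGINT(20)" matches because "BIGINT" (and "INT") is a substring of the upper-cased type.
    return bool(column_type) and any(t in column_type.upper() for t in NUM_TYPES)

# is_num("bigint(20)") -> True; is_num("VARCHAR(255)") -> False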
@property @property
def expression(self): def expression(self):
@ -397,9 +385,17 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
@property @property
def data(self): def data(self):
attrs = ( attrs = (
'id', 'column_name', 'verbose_name', 'description', 'expression', "id",
'filterable', 'groupby', 'is_dttm', 'type', "column_name",
'database_expression', 'python_date_format', "verbose_name",
"description",
"expression",
"filterable",
"groupby",
"is_dttm",
"type",
"database_expression",
"python_date_format",
) )
return {s: getattr(self, s) for s in attrs if hasattr(self, s)} return {s: getattr(self, s) for s in attrs if hasattr(self, s)}
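The data property above serializes a fixed whitelist of attributes, skipping any the instance lacks. The same getattr/hasattr pattern on a stand-in object:

class FakeColumn:
    def __init__(self):
        self.id = 1
        self.column_name = "ds"
        # verbose_name deliberately left unset

attrs = ("id", "column_name", "verbose_name")
col = FakeColumn()
data = {s: getattr(col, s) for s in attrs if hasattr(col, s)}
# data == {"id": 1, "column_name": "ds"}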
@ -432,6 +428,7 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
backref=backref('metrics', cascade='all, delete-orphan'), backref=backref('metrics', cascade='all, delete-orphan'),
enable_typechecks=False) enable_typechecks=False)
""" """
@property @property
def perm(self): def perm(self):
raise NotImplementedError() raise NotImplementedError()
@ -443,6 +440,12 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
@property @property
def data(self): def data(self):
attrs = ( attrs = (
'id', 'metric_name', 'verbose_name', 'description', 'expression', "id",
'warning_text', 'd3format') "metric_name",
"verbose_name",
"description",
"expression",
"warning_text",
"d3format",
)
return {s: getattr(self, s) for s in attrs} return {s: getattr(self, s) for s in attrs}


@ -24,7 +24,10 @@ from superset.views.base import SupersetModelView
class DatasourceModelView(SupersetModelView): class DatasourceModelView(SupersetModelView):
def pre_delete(self, obj): def pre_delete(self, obj):
if obj.slices: if obj.slices:
raise SupersetException(Markup( raise SupersetException(
'Cannot delete a datasource that has slices attached to it.' Markup(
"Here's the list of associated charts: " + "Cannot delete a datasource that has slices attached to it."
''.join([o.slice_name for o in obj.slices]))) "Here's the list of associated charts: "
+ "".join([o.slice_name for o in obj.slices])
)
)


@ -51,15 +51,21 @@ class ConnectorRegistry(object):
return datasources return datasources
@classmethod @classmethod
def get_datasource_by_name(cls, session, datasource_type, datasource_name, def get_datasource_by_name(
schema, database_name): cls, session, datasource_type, datasource_name, schema, database_name
):
datasource_class = ConnectorRegistry.sources[datasource_type] datasource_class = ConnectorRegistry.sources[datasource_type]
datasources = session.query(datasource_class).all() datasources = session.query(datasource_class).all()
# Filter datasources that don't have a database. # Filter datasources that don't have a database.
db_ds = [d for d in datasources if d.database and db_ds = [
d.database.name == database_name and d
d.name == datasource_name and schema == schema] for d in datasources
if d.database
and d.database.name == database_name
and d.name == datasource_name
and schema == schema
]
return db_ds[0] return db_ds[0]
@classmethod @classmethod
@ -87,8 +93,8 @@ class ConnectorRegistry(object):
) )
@classmethod @classmethod
def query_datasources_by_name( def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
cls, session, database, datasource_name, schema=None):
datasource_class = ConnectorRegistry.sources[database.type] datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name( return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=None) session, database, datasource_name, schema=None
)

File diff suppressed because it is too large


@ -33,9 +33,14 @@ from superset.connectors.base.views import DatasourceModelView
from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.base import ( from superset.views.base import (
BaseSupersetView, DatasourceFilter, DeleteMixin, BaseSupersetView,
get_datasource_exist_error_msg, ListWidgetWithCheckboxes, SupersetModelView, DatasourceFilter,
validate_json, YamlExportMixin, DeleteMixin,
get_datasource_exist_error_msg,
ListWidgetWithCheckboxes,
SupersetModelView,
validate_json,
YamlExportMixin,
) )
from . import models from . import models
@ -43,48 +48,56 @@ from . import models
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidColumn) datamodel = SQLAInterface(models.DruidColumn)
list_title = _('Columns') list_title = _("Columns")
show_title = _('Show Druid Column') show_title = _("Show Druid Column")
add_title = _('Add Druid Column') add_title = _("Add Druid Column")
edit_title = _('Edit Druid Column') edit_title = _("Edit Druid Column")
list_widget = ListWidgetWithCheckboxes list_widget = ListWidgetWithCheckboxes
edit_columns = [ edit_columns = [
'column_name', 'verbose_name', 'description', 'dimension_spec_json', 'datasource', "column_name",
'groupby', 'filterable'] "verbose_name",
"description",
"dimension_spec_json",
"datasource",
"groupby",
"filterable",
]
add_columns = edit_columns add_columns = edit_columns
list_columns = ['column_name', 'verbose_name', 'type', 'groupby', 'filterable'] list_columns = ["column_name", "verbose_name", "type", "groupby", "filterable"]
can_delete = False can_delete = False
page_size = 500 page_size = 500
label_columns = { label_columns = {
'column_name': _('Column'), "column_name": _("Column"),
'type': _('Type'), "type": _("Type"),
'datasource': _('Datasource'), "datasource": _("Datasource"),
'groupby': _('Groupable'), "groupby": _("Groupable"),
'filterable': _('Filterable'), "filterable": _("Filterable"),
} }
description_columns = { description_columns = {
'filterable': _( "filterable": _(
'Whether this column is exposed in the `Filters` section ' "Whether this column is exposed in the `Filters` section "
'of the explore view.'), "of the explore view."
'dimension_spec_json': utils.markdown( ),
'this field can be used to specify ' "dimension_spec_json": utils.markdown(
'a `dimensionSpec` as documented [here]' "this field can be used to specify "
'(http://druid.io/docs/latest/querying/dimensionspecs.html). ' "a `dimensionSpec` as documented [here]"
'Make sure to input valid JSON and that the ' "(http://druid.io/docs/latest/querying/dimensionspecs.html). "
'`outputName` matches the `column_name` defined ' "Make sure to input valid JSON and that the "
'above.', "`outputName` matches the `column_name` defined "
True), "above.",
True,
),
} }
add_form_extra_fields = { add_form_extra_fields = {
'datasource': QuerySelectField( "datasource": QuerySelectField(
'Datasource', "Datasource",
query_factory=lambda: db.session().query(models.DruidDatasource), query_factory=lambda: db.session().query(models.DruidDatasource),
allow_blank=True, allow_blank=True,
widget=Select2Widget(extra_classes='readonly'), widget=Select2Widget(extra_classes="readonly"),
), )
} }
edit_form_extra_fields = add_form_extra_fields edit_form_extra_fields = add_form_extra_fields
@ -96,18 +109,20 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
try: try:
dimension_spec = json.loads(col.dimension_spec_json) dimension_spec = json.loads(col.dimension_spec_json)
except ValueError as e: except ValueError as e:
raise ValueError('Invalid Dimension Spec JSON: ' + str(e)) raise ValueError("Invalid Dimension Spec JSON: " + str(e))
if not isinstance(dimension_spec, dict): if not isinstance(dimension_spec, dict):
raise ValueError('Dimension Spec must be a JSON object') raise ValueError("Dimension Spec must be a JSON object")
if 'outputName' not in dimension_spec: if "outputName" not in dimension_spec:
raise ValueError('Dimension Spec does not contain `outputName`') raise ValueError("Dimension Spec does not contain `outputName`")
if 'dimension' not in dimension_spec: if "dimension" not in dimension_spec:
raise ValueError('Dimension Spec is missing `dimension`') raise ValueError("Dimension Spec is missing `dimension`")
# `outputName` should be the same as the `column_name` # `outputName` should be the same as the `column_name`
if dimension_spec['outputName'] != col.column_name: if dimension_spec["outputName"] != col.column_name:
raise ValueError( raise ValueError(
'`outputName` [{}] unequal to `column_name` [{}]' "`outputName` [{}] unequal to `column_name` [{}]".format(
.format(dimension_spec['outputName'], col.column_name)) dimension_spec["outputName"], col.column_name
)
)
def post_update(self, col): def post_update(self, col):
col.refresh_metrics() col.refresh_metrics()
@ -122,60 +137,73 @@ appbuilder.add_view_no_menu(DruidColumnInlineView)
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidMetric) datamodel = SQLAInterface(models.DruidMetric)
list_title = _('Metrics') list_title = _("Metrics")
show_title = _('Show Druid Metric') show_title = _("Show Druid Metric")
add_title = _('Add Druid Metric') add_title = _("Add Druid Metric")
edit_title = _('Edit Druid Metric') edit_title = _("Edit Druid Metric")
list_columns = ['metric_name', 'verbose_name', 'metric_type'] list_columns = ["metric_name", "verbose_name", "metric_type"]
edit_columns = [ edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type', 'json', "metric_name",
'datasource', 'd3format', 'is_restricted', 'warning_text'] "description",
"verbose_name",
"metric_type",
"json",
"datasource",
"d3format",
"is_restricted",
"warning_text",
]
add_columns = edit_columns add_columns = edit_columns
page_size = 500 page_size = 500
validators_columns = { validators_columns = {"json": [validate_json]}
'json': [validate_json],
}
description_columns = { description_columns = {
'metric_type': utils.markdown( "metric_type": utils.markdown(
'use `postagg` as the metric type if you are defining a ' "use `postagg` as the metric type if you are defining a "
'[Druid Post Aggregation]' "[Druid Post Aggregation]"
'(http://druid.io/docs/latest/querying/post-aggregations.html)', "(http://druid.io/docs/latest/querying/post-aggregations.html)",
True), True,
'is_restricted': _('Whether access to this metric is restricted ' ),
'to certain roles. Only roles with the permission ' "is_restricted": _(
"'metric access on XXX (the name of this metric)' " "Whether access to this metric is restricted "
'are allowed to access this metric'), "to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"
),
} }
label_columns = { label_columns = {
'metric_name': _('Metric'), "metric_name": _("Metric"),
'description': _('Description'), "description": _("Description"),
'verbose_name': _('Verbose Name'), "verbose_name": _("Verbose Name"),
'metric_type': _('Type'), "metric_type": _("Type"),
'json': _('JSON'), "json": _("JSON"),
'datasource': _('Druid Datasource'), "datasource": _("Druid Datasource"),
'warning_text': _('Warning Message'), "warning_text": _("Warning Message"),
'is_restricted': _('Is Restricted'), "is_restricted": _("Is Restricted"),
} }
add_form_extra_fields = { add_form_extra_fields = {
'datasource': QuerySelectField( "datasource": QuerySelectField(
'Datasource', "Datasource",
query_factory=lambda: db.session().query(models.DruidDatasource), query_factory=lambda: db.session().query(models.DruidDatasource),
allow_blank=True, allow_blank=True,
widget=Select2Widget(extra_classes='readonly'), widget=Select2Widget(extra_classes="readonly"),
), )
} }
edit_form_extra_fields = add_form_extra_fields edit_form_extra_fields = add_form_extra_fields
def post_add(self, metric): def post_add(self, metric):
if metric.is_restricted: if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm()) security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
def post_update(self, metric): def post_update(self, metric):
if metric.is_restricted: if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm()) security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
appbuilder.add_view_no_menu(DruidMetricInlineView) appbuilder.add_view_no_menu(DruidMetricInlineView)
@ -184,57 +212,63 @@ appbuilder.add_view_no_menu(DruidMetricInlineView)
class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.DruidCluster) datamodel = SQLAInterface(models.DruidCluster)
list_title = _('Druid Clusters') list_title = _("Druid Clusters")
show_title = _('Show Druid Cluster') show_title = _("Show Druid Cluster")
add_title = _('Add Druid Cluster') add_title = _("Add Druid Cluster")
edit_title = _('Edit Druid Cluster') edit_title = _("Edit Druid Cluster")
add_columns = [ add_columns = [
'verbose_name', 'broker_host', 'broker_port', "verbose_name",
'broker_user', 'broker_pass', 'broker_endpoint', "broker_host",
'cache_timeout', 'cluster_name', "broker_port",
"broker_user",
"broker_pass",
"broker_endpoint",
"cache_timeout",
"cluster_name",
] ]
edit_columns = add_columns edit_columns = add_columns
list_columns = ['cluster_name', 'metadata_last_refreshed'] list_columns = ["cluster_name", "metadata_last_refreshed"]
search_columns = ('cluster_name',) search_columns = ("cluster_name",)
label_columns = { label_columns = {
'cluster_name': _('Cluster'), "cluster_name": _("Cluster"),
'broker_host': _('Broker Host'), "broker_host": _("Broker Host"),
'broker_port': _('Broker Port'), "broker_port": _("Broker Port"),
'broker_user': _('Broker Username'), "broker_user": _("Broker Username"),
'broker_pass': _('Broker Password'), "broker_pass": _("Broker Password"),
'broker_endpoint': _('Broker Endpoint'), "broker_endpoint": _("Broker Endpoint"),
'verbose_name': _('Verbose Name'), "verbose_name": _("Verbose Name"),
'cache_timeout': _('Cache Timeout'), "cache_timeout": _("Cache Timeout"),
'metadata_last_refreshed': _('Metadata Last Refreshed'), "metadata_last_refreshed": _("Metadata Last Refreshed"),
} }
description_columns = { description_columns = {
'cache_timeout': _( "cache_timeout": _(
'Duration (in seconds) of the caching timeout for this cluster. ' "Duration (in seconds) of the caching timeout for this cluster. "
'A timeout of 0 indicates that the cache never expires. ' "A timeout of 0 indicates that the cache never expires. "
'Note this defaults to the global timeout if undefined.'), "Note this defaults to the global timeout if undefined."
'broker_user': _(
'Druid supports basic authentication. See '
'[auth](http://druid.io/docs/latest/design/auth.html) and '
'druid-basic-security extension',
), ),
'broker_pass': _( "broker_user": _(
'Druid supports basic authentication. See ' "Druid supports basic authentication. See "
'[auth](http://druid.io/docs/latest/design/auth.html) and ' "[auth](http://druid.io/docs/latest/design/auth.html) and "
'druid-basic-security extension', "druid-basic-security extension"
),
"broker_pass": _(
"Druid supports basic authentication. See "
"[auth](http://druid.io/docs/latest/design/auth.html) and "
"druid-basic-security extension"
), ),
} }
edit_form_extra_fields = { edit_form_extra_fields = {
'cluster_name': QuerySelectField( "cluster_name": QuerySelectField(
'Cluster', "Cluster",
query_factory=lambda: db.session().query(models.DruidCluster), query_factory=lambda: db.session().query(models.DruidCluster),
widget=Select2Widget(extra_classes='readonly'), widget=Select2Widget(extra_classes="readonly"),
), )
} }
def pre_add(self, cluster): def pre_add(self, cluster):
security_manager.add_permission_view_menu('database_access', cluster.perm) security_manager.add_permission_view_menu("database_access", cluster.perm)
def pre_update(self, cluster): def pre_update(self, cluster):
self.pre_add(cluster) self.pre_add(cluster)
@ -245,112 +279,118 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): #
appbuilder.add_view( appbuilder.add_view(
DruidClusterModelView, DruidClusterModelView,
name='Druid Clusters', name="Druid Clusters",
label=__('Druid Clusters'), label=__("Druid Clusters"),
icon='fa-cubes', icon="fa-cubes",
category='Sources', category="Sources",
category_label=__('Sources'), category_label=__("Sources"),
category_icon='fa-database', category_icon="fa-database",
) )
class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa class DruidDatasourceModelView(
DatasourceModelView, DeleteMixin, YamlExportMixin
): # noqa
datamodel = SQLAInterface(models.DruidDatasource) datamodel = SQLAInterface(models.DruidDatasource)
list_title = _('Druid Datasources') list_title = _("Druid Datasources")
show_title = _('Show Druid Datasource') show_title = _("Show Druid Datasource")
add_title = _('Add Druid Datasource') add_title = _("Add Druid Datasource")
edit_title = _('Edit Druid Datasource') edit_title = _("Edit Druid Datasource")
list_columns = [ list_columns = ["datasource_link", "cluster", "changed_by_", "modified"]
'datasource_link', 'cluster', 'changed_by_', 'modified'] order_columns = ["datasource_link", "modified"]
order_columns = ['datasource_link', 'modified']
related_views = [DruidColumnInlineView, DruidMetricInlineView] related_views = [DruidColumnInlineView, DruidMetricInlineView]
edit_columns = [ edit_columns = [
'datasource_name', 'cluster', 'description', 'owners', "datasource_name",
'is_hidden', "cluster",
'filter_select_enabled', 'fetch_values_from', "description",
'default_endpoint', 'offset', 'cache_timeout'] "owners",
search_columns = ( "is_hidden",
'datasource_name', 'cluster', 'description', 'owners', "filter_select_enabled",
) "fetch_values_from",
"default_endpoint",
"offset",
"cache_timeout",
]
search_columns = ("datasource_name", "cluster", "description", "owners")
add_columns = edit_columns add_columns = edit_columns
show_columns = add_columns + ['perm', 'slices'] show_columns = add_columns + ["perm", "slices"]
page_size = 500 page_size = 500
base_order = ('datasource_name', 'asc') base_order = ("datasource_name", "asc")
description_columns = { description_columns = {
'slices': _( "slices": _(
'The list of charts associated with this table. By ' "The list of charts associated with this table. By "
'altering this datasource, you may change how these associated ' "altering this datasource, you may change how these associated "
'charts behave. ' "charts behave. "
'Also note that charts need to point to a datasource, so ' "Also note that charts need to point to a datasource, so "
'this form will fail at saving if removing charts from a ' "this form will fail at saving if removing charts from a "
'datasource. If you want to change the datasource for a chart, ' "datasource. If you want to change the datasource for a chart, "
"overwrite the chart from the 'explore view'"), "overwrite the chart from the 'explore view'"
'offset': _('Timezone offset (in hours) for this datasource'), ),
'description': Markup( "offset": _("Timezone offset (in hours) for this datasource"),
"description": Markup(
'Supports <a href="' 'Supports <a href="'
'https://daringfireball.net/projects/markdown/">markdown</a>'), 'https://daringfireball.net/projects/markdown/">markdown</a>'
'fetch_values_from': _( ),
'Time expression to use as a predicate when retrieving ' "fetch_values_from": _(
'distinct values to populate the filter component. ' "Time expression to use as a predicate when retrieving "
'Only applies when `Enable Filter Select` is on. If ' "distinct values to populate the filter component. "
'you enter `7 days ago`, the distinct list of values in ' "Only applies when `Enable Filter Select` is on. If "
'the filter will be populated based on the distinct value over ' "you enter `7 days ago`, the distinct list of values in "
'the past week'), "the filter will be populated based on the distinct value over "
'filter_select_enabled': _( "the past week"
),
"filter_select_enabled": _(
"Whether to populate the filter's dropdown in the explore " "Whether to populate the filter's dropdown in the explore "
"view's filter section with a list of distinct values fetched " "view's filter section with a list of distinct values fetched "
'from the backend on the fly'), "from the backend on the fly"
'default_endpoint': _( ),
'Redirects to this endpoint when clicking on the datasource ' "default_endpoint": _(
'from the datasource list'), "Redirects to this endpoint when clicking on the datasource "
'cache_timeout': _( "from the datasource list"
'Duration (in seconds) of the caching timeout for this datasource. ' ),
'A timeout of 0 indicates that the cache never expires. ' "cache_timeout": _(
'Note this defaults to the cluster timeout if undefined.'), "Duration (in seconds) of the caching timeout for this datasource. "
"A timeout of 0 indicates that the cache never expires. "
"Note this defaults to the cluster timeout if undefined."
),
} }
base_filters = [['id', DatasourceFilter, lambda: []]] base_filters = [["id", DatasourceFilter, lambda: []]]
label_columns = { label_columns = {
'slices': _('Associated Charts'), "slices": _("Associated Charts"),
'datasource_link': _('Data Source'), "datasource_link": _("Data Source"),
'cluster': _('Cluster'), "cluster": _("Cluster"),
'description': _('Description'), "description": _("Description"),
'owners': _('Owners'), "owners": _("Owners"),
'is_hidden': _('Is Hidden'), "is_hidden": _("Is Hidden"),
'filter_select_enabled': _('Enable Filter Select'), "filter_select_enabled": _("Enable Filter Select"),
'default_endpoint': _('Default Endpoint'), "default_endpoint": _("Default Endpoint"),
'offset': _('Time Offset'), "offset": _("Time Offset"),
'cache_timeout': _('Cache Timeout'), "cache_timeout": _("Cache Timeout"),
'datasource_name': _('Datasource Name'), "datasource_name": _("Datasource Name"),
'fetch_values_from': _('Fetch Values From'), "fetch_values_from": _("Fetch Values From"),
'changed_by_': _('Changed By'), "changed_by_": _("Changed By"),
'modified': _('Modified'), "modified": _("Modified"),
} }
def pre_add(self, datasource): def pre_add(self, datasource):
with db.session.no_autoflush: with db.session.no_autoflush:
query = ( query = db.session.query(models.DruidDatasource).filter(
db.session.query(models.DruidDatasource) models.DruidDatasource.datasource_name == datasource.datasource_name,
.filter(models.DruidDatasource.datasource_name == models.DruidDatasource.cluster_name == datasource.cluster.id,
datasource.datasource_name,
models.DruidDatasource.cluster_name ==
datasource.cluster.id)
) )
if db.session.query(query.exists()).scalar(): if db.session.query(query.exists()).scalar():
raise Exception(get_datasource_exist_error_msg( raise Exception(get_datasource_exist_error_msg(datasource.full_name))
datasource.full_name))
def post_add(self, datasource): def post_add(self, datasource):
datasource.refresh_metrics() datasource.refresh_metrics()
security_manager.add_permission_view_menu( security_manager.add_permission_view_menu(
'datasource_access', "datasource_access", datasource.get_perm()
datasource.get_perm(),
) )
if datasource.schema: if datasource.schema:
security_manager.add_permission_view_menu( security_manager.add_permission_view_menu(
'schema_access', "schema_access", datasource.schema_perm
datasource.schema_perm,
) )
def post_update(self, datasource): def post_update(self, datasource):
@ -362,22 +402,23 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
appbuilder.add_view( appbuilder.add_view(
DruidDatasourceModelView, DruidDatasourceModelView,
'Druid Datasources', "Druid Datasources",
label=__('Druid Datasources'), label=__("Druid Datasources"),
category='Sources', category="Sources",
category_label=__('Sources'), category_label=__("Sources"),
icon='fa-cube') icon="fa-cube",
)
class Druid(BaseSupersetView): class Druid(BaseSupersetView):
"""The base views for Superset!""" """The base views for Superset!"""
@has_access @has_access
@expose('/refresh_datasources/') @expose("/refresh_datasources/")
def refresh_datasources(self, refreshAll=True): def refresh_datasources(self, refreshAll=True):
"""endpoint that refreshes druid datasources metadata""" """endpoint that refreshes druid datasources metadata"""
session = db.session() session = db.session()
DruidCluster = ConnectorRegistry.sources['druid'].cluster_class DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
for cluster in session.query(DruidCluster).all(): for cluster in session.query(DruidCluster).all():
cluster_name = cluster.cluster_name cluster_name = cluster.cluster_name
valid_cluster = True valid_cluster = True
@ -387,21 +428,25 @@ class Druid(BaseSupersetView):
valid_cluster = False valid_cluster = False
flash( flash(
"Error while processing cluster '{}'\n{}".format( "Error while processing cluster '{}'\n{}".format(
cluster_name, utils.error_msg_from_exception(e)), cluster_name, utils.error_msg_from_exception(e)
'danger') ),
"danger",
)
logging.exception(e) logging.exception(e)
pass pass
if valid_cluster: if valid_cluster:
cluster.metadata_last_refreshed = datetime.now() cluster.metadata_last_refreshed = datetime.now()
flash( flash(
_('Refreshed metadata from cluster [{}]').format( _("Refreshed metadata from cluster [{}]").format(
cluster.cluster_name), cluster.cluster_name
'info') ),
"info",
)
session.commit() session.commit()
return redirect('/druiddatasourcemodelview/list/') return redirect("/druiddatasourcemodelview/list/")
@has_access @has_access
@expose('/scan_new_datasources/') @expose("/scan_new_datasources/")
def scan_new_datasources(self): def scan_new_datasources(self):
""" """
Calling this endpoint will cause a scan for new Calling this endpoint will cause a scan for new
@ -413,21 +458,23 @@ class Druid(BaseSupersetView):
appbuilder.add_view_no_menu(Druid) appbuilder.add_view_no_menu(Druid)
appbuilder.add_link( appbuilder.add_link(
'Scan New Datasources', "Scan New Datasources",
label=__('Scan New Datasources'), label=__("Scan New Datasources"),
href='/druid/scan_new_datasources/', href="/druid/scan_new_datasources/",
category='Sources', category="Sources",
category_label=__('Sources'), category_label=__("Sources"),
category_icon='fa-database', category_icon="fa-database",
icon='fa-refresh') icon="fa-refresh",
)
appbuilder.add_link( appbuilder.add_link(
'Refresh Druid Metadata', "Refresh Druid Metadata",
label=__('Refresh Druid Metadata'), label=__("Refresh Druid Metadata"),
href='/druid/refresh_datasources/', href="/druid/refresh_datasources/",
category='Sources', category="Sources",
category_label=__('Sources'), category_label=__("Sources"),
category_icon='fa-database', category_icon="fa-database",
icon='fa-cog') icon="fa-cog",
)
appbuilder.add_separator('Sources') appbuilder.add_separator("Sources")
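Most of the churn in this file, and in the rest of the commit, is Black normalizing string quotes from single to double. The change is purely cosmetic: both spellings denote the same value, so permission names, endpoints, and translated labels are untouched at runtime, and Black keeps single quotes wherever switching would force escaping of an embedded double quote (as in the HTML help strings further down). A trivial check:

    assert 'Druid Datasources' == "Druid Datasources"
    assert 'datasource_access' == "datasource_access"
    # Quote style is not preserved in the string object; repr() just picks
    # whichever quote needs the least escaping.
    print(repr("a 'quoted' word"), repr('a "quoted" word'))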

File diff suppressed because it is too large

View File

@ -32,8 +32,12 @@ from superset import appbuilder, db, security_manager
from superset.connectors.base.views import DatasourceModelView from superset.connectors.base.views import DatasourceModelView
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.base import ( from superset.views.base import (
DatasourceFilter, DeleteMixin, get_datasource_exist_error_msg, DatasourceFilter,
ListWidgetWithCheckboxes, SupersetModelView, YamlExportMixin, DeleteMixin,
get_datasource_exist_error_msg,
ListWidgetWithCheckboxes,
SupersetModelView,
YamlExportMixin,
) )
from . import models from . import models
@ -43,79 +47,103 @@ logger = logging.getLogger(__name__)
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.TableColumn) datamodel = SQLAInterface(models.TableColumn)
list_title = _('Columns') list_title = _("Columns")
show_title = _('Show Column') show_title = _("Show Column")
add_title = _('Add Column') add_title = _("Add Column")
edit_title = _('Edit Column') edit_title = _("Edit Column")
can_delete = False can_delete = False
list_widget = ListWidgetWithCheckboxes list_widget = ListWidgetWithCheckboxes
edit_columns = [ edit_columns = [
'column_name', 'verbose_name', 'description', "column_name",
'type', 'groupby', 'filterable', "verbose_name",
'table', 'expression', "description",
'is_dttm', 'python_date_format', 'database_expression'] "type",
"groupby",
"filterable",
"table",
"expression",
"is_dttm",
"python_date_format",
"database_expression",
]
add_columns = edit_columns add_columns = edit_columns
list_columns = [ list_columns = [
'column_name', 'verbose_name', 'type', 'groupby', 'filterable', "column_name",
'is_dttm'] "verbose_name",
"type",
"groupby",
"filterable",
"is_dttm",
]
page_size = 500 page_size = 500
description_columns = { description_columns = {
'is_dttm': _( "is_dttm": _(
'Whether to make this column available as a ' "Whether to make this column available as a "
'[Time Granularity] option, column has to be DATETIME or ' "[Time Granularity] option, column has to be DATETIME or "
'DATETIME-like'), "DATETIME-like"
'filterable': _( ),
'Whether this column is exposed in the `Filters` section ' "filterable": _(
'of the explore view.'), "Whether this column is exposed in the `Filters` section "
'type': _( "of the explore view."
'The data type that was inferred by the database. ' ),
'It may be necessary to input a type manually for ' "type": _(
'expression-defined columns in some cases. In most case ' "The data type that was inferred by the database. "
'users should not need to alter this.'), "It may be necessary to input a type manually for "
'expression': utils.markdown( "expression-defined columns in some cases. In most case "
'a valid, *non-aggregating* SQL expression as supported by the ' "users should not need to alter this."
'underlying backend. Example: `substr(name, 1, 1)`', True), ),
'python_date_format': utils.markdown(Markup( "expression": utils.markdown(
'The pattern of timestamp format, use ' "a valid, *non-aggregating* SQL expression as supported by the "
'<a href="https://docs.python.org/2/library/' "underlying backend. Example: `substr(name, 1, 1)`",
'datetime.html#strftime-strptime-behavior">' True,
'python datetime string pattern</a> ' ),
'expression. If time is stored in epoch ' "python_date_format": utils.markdown(
'format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` ' Markup(
'below empty if timestamp is stored in ' "The pattern of timestamp format, use "
'String or Integer(epoch) type'), True), '<a href="https://docs.python.org/2/library/'
'database_expression': utils.markdown( 'datetime.html#strftime-strptime-behavior">'
'The database expression to cast internal datetime ' "python datetime string pattern</a> "
'constants to database date/timestamp type according to the DBAPI. ' "expression. If time is stored in epoch "
'The expression should follow the pattern of ' "format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
'%Y-%m-%d %H:%M:%S, based on different DBAPI. ' "below empty if timestamp is stored in "
'The string should be a python string formatter \n' "String or Integer(epoch) type"
),
True,
),
"database_expression": utils.markdown(
"The database expression to cast internal datetime "
"constants to database date/timestamp type according to the DBAPI. "
"The expression should follow the pattern of "
"%Y-%m-%d %H:%M:%S, based on different DBAPI. "
"The string should be a python string formatter \n"
"`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle " "`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle "
'Superset uses default expression based on DB URI if this ' "Superset uses default expression based on DB URI if this "
'field is blank.', True), "field is blank.",
True,
),
} }
label_columns = { label_columns = {
'column_name': _('Column'), "column_name": _("Column"),
'verbose_name': _('Verbose Name'), "verbose_name": _("Verbose Name"),
'description': _('Description'), "description": _("Description"),
'groupby': _('Groupable'), "groupby": _("Groupable"),
'filterable': _('Filterable'), "filterable": _("Filterable"),
'table': _('Table'), "table": _("Table"),
'expression': _('Expression'), "expression": _("Expression"),
'is_dttm': _('Is temporal'), "is_dttm": _("Is temporal"),
'python_date_format': _('Datetime Format'), "python_date_format": _("Datetime Format"),
'database_expression': _('Database Expression'), "database_expression": _("Database Expression"),
'type': _('Type'), "type": _("Type"),
} }
add_form_extra_fields = { add_form_extra_fields = {
'table': QuerySelectField( "table": QuerySelectField(
'Table', "Table",
query_factory=lambda: db.session().query(models.SqlaTable), query_factory=lambda: db.session().query(models.SqlaTable),
allow_blank=True, allow_blank=True,
widget=Select2Widget(extra_classes='readonly'), widget=Select2Widget(extra_classes="readonly"),
), )
} }
edit_form_extra_fields = add_form_extra_fields edit_form_extra_fields = add_form_extra_fields
@ -127,63 +155,80 @@ appbuilder.add_view_no_menu(TableColumnInlineView)
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.SqlMetric) datamodel = SQLAInterface(models.SqlMetric)
list_title = _('Metrics') list_title = _("Metrics")
show_title = _('Show Metric') show_title = _("Show Metric")
add_title = _('Add Metric') add_title = _("Add Metric")
edit_title = _('Edit Metric') edit_title = _("Edit Metric")
list_columns = ['metric_name', 'verbose_name', 'metric_type'] list_columns = ["metric_name", "verbose_name", "metric_type"]
edit_columns = [ edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type', "metric_name",
'expression', 'table', 'd3format', 'is_restricted', 'warning_text'] "description",
"verbose_name",
"metric_type",
"expression",
"table",
"d3format",
"is_restricted",
"warning_text",
]
description_columns = { description_columns = {
'expression': utils.markdown( "expression": utils.markdown(
'a valid, *aggregating* SQL expression as supported by the ' "a valid, *aggregating* SQL expression as supported by the "
'underlying backend. Example: `count(DISTINCT userid)`', True), "underlying backend. Example: `count(DISTINCT userid)`",
'is_restricted': _('Whether access to this metric is restricted ' True,
'to certain roles. Only roles with the permission ' ),
"'metric access on XXX (the name of this metric)' " "is_restricted": _(
'are allowed to access this metric'), "Whether access to this metric is restricted "
'd3format': utils.markdown( "to certain roles. Only roles with the permission "
'd3 formatting string as defined [here]' "'metric access on XXX (the name of this metric)' "
'(https://github.com/d3/d3-format/blob/master/README.md#format). ' "are allowed to access this metric"
'For instance, this default formatting applies in the Table ' ),
'visualization and allow for different metric to use different ' "d3format": utils.markdown(
'formats', True, "d3 formatting string as defined [here]"
"(https://github.com/d3/d3-format/blob/master/README.md#format). "
"For instance, this default formatting applies in the Table "
"visualization and allow for different metric to use different "
"formats",
True,
), ),
} }
add_columns = edit_columns add_columns = edit_columns
page_size = 500 page_size = 500
label_columns = { label_columns = {
'metric_name': _('Metric'), "metric_name": _("Metric"),
'description': _('Description'), "description": _("Description"),
'verbose_name': _('Verbose Name'), "verbose_name": _("Verbose Name"),
'metric_type': _('Type'), "metric_type": _("Type"),
'expression': _('SQL Expression'), "expression": _("SQL Expression"),
'table': _('Table'), "table": _("Table"),
'd3format': _('D3 Format'), "d3format": _("D3 Format"),
'is_restricted': _('Is Restricted'), "is_restricted": _("Is Restricted"),
'warning_text': _('Warning Message'), "warning_text": _("Warning Message"),
} }
add_form_extra_fields = { add_form_extra_fields = {
'table': QuerySelectField( "table": QuerySelectField(
'Table', "Table",
query_factory=lambda: db.session().query(models.SqlaTable), query_factory=lambda: db.session().query(models.SqlaTable),
allow_blank=True, allow_blank=True,
widget=Select2Widget(extra_classes='readonly'), widget=Select2Widget(extra_classes="readonly"),
), )
} }
edit_form_extra_fields = add_form_extra_fields edit_form_extra_fields = add_form_extra_fields
def post_add(self, metric): def post_add(self, metric):
if metric.is_restricted: if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm()) security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
def post_update(self, metric): def post_update(self, metric):
if metric.is_restricted: if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm()) security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
appbuilder.add_view_no_menu(SqlMetricInlineView) appbuilder.add_view_no_menu(SqlMetricInlineView)
@ -192,104 +237,114 @@ appbuilder.add_view_no_menu(SqlMetricInlineView)
class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable) datamodel = SQLAInterface(models.SqlaTable)
list_title = _('Tables') list_title = _("Tables")
show_title = _('Show Table') show_title = _("Show Table")
add_title = _('Import a table definition') add_title = _("Import a table definition")
edit_title = _('Edit Table') edit_title = _("Edit Table")
list_columns = [ list_columns = ["link", "database_name", "changed_by_", "modified"]
'link', 'database_name', order_columns = ["modified"]
'changed_by_', 'modified'] add_columns = ["database", "schema", "table_name"]
order_columns = ['modified']
add_columns = ['database', 'schema', 'table_name']
edit_columns = [ edit_columns = [
'table_name', 'sql', 'filter_select_enabled', "table_name",
'fetch_values_predicate', 'database', 'schema', "sql",
'description', 'owners', "filter_select_enabled",
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout', "fetch_values_predicate",
'is_sqllab_view', 'template_params', "database",
"schema",
"description",
"owners",
"main_dttm_col",
"default_endpoint",
"offset",
"cache_timeout",
"is_sqllab_view",
"template_params",
] ]
base_filters = [['id', DatasourceFilter, lambda: []]] base_filters = [["id", DatasourceFilter, lambda: []]]
show_columns = edit_columns + ['perm', 'slices'] show_columns = edit_columns + ["perm", "slices"]
related_views = [TableColumnInlineView, SqlMetricInlineView] related_views = [TableColumnInlineView, SqlMetricInlineView]
base_order = ('changed_on', 'desc') base_order = ("changed_on", "desc")
search_columns = ( search_columns = ("database", "schema", "table_name", "owners", "is_sqllab_view")
'database', 'schema', 'table_name', 'owners', 'is_sqllab_view',
)
description_columns = { description_columns = {
'slices': _( "slices": _(
'The list of charts associated with this table. By ' "The list of charts associated with this table. By "
'altering this datasource, you may change how these associated ' "altering this datasource, you may change how these associated "
'charts behave. ' "charts behave. "
'Also note that charts need to point to a datasource, so ' "Also note that charts need to point to a datasource, so "
'this form will fail at saving if removing charts from a ' "this form will fail at saving if removing charts from a "
'datasource. If you want to change the datasource for a chart, ' "datasource. If you want to change the datasource for a chart, "
"overwrite the chart from the 'explore view'"), "overwrite the chart from the 'explore view'"
'offset': _('Timezone offset (in hours) for this datasource'), ),
'table_name': _( "offset": _("Timezone offset (in hours) for this datasource"),
'Name of the table that exists in the source database'), "table_name": _("Name of the table that exists in the source database"),
'schema': _( "schema": _(
'Schema, as used only in some databases like Postgres, Redshift ' "Schema, as used only in some databases like Postgres, Redshift " "and DB2"
'and DB2'), ),
'description': Markup( "description": Markup(
'Supports <a href="https://daringfireball.net/projects/markdown/">' 'Supports <a href="https://daringfireball.net/projects/markdown/">'
'markdown</a>'), "markdown</a>"
'sql': _(
'This fields acts a Superset view, meaning that Superset will '
'run a query against this string as a subquery.',
), ),
'fetch_values_predicate': _( "sql": _(
'Predicate applied when fetching distinct value to ' "This fields acts a Superset view, meaning that Superset will "
'populate the filter control component. Supports ' "run a query against this string as a subquery."
'jinja template syntax. Applies only when '
'`Enable Filter Select` is on.',
), ),
'default_endpoint': _( "fetch_values_predicate": _(
'Redirects to this endpoint when clicking on the table ' "Predicate applied when fetching distinct value to "
'from the table list'), "populate the filter control component. Supports "
'filter_select_enabled': _( "jinja template syntax. Applies only when "
"`Enable Filter Select` is on."
),
"default_endpoint": _(
"Redirects to this endpoint when clicking on the table "
"from the table list"
),
"filter_select_enabled": _(
"Whether to populate the filter's dropdown in the explore " "Whether to populate the filter's dropdown in the explore "
"view's filter section with a list of distinct values fetched " "view's filter section with a list of distinct values fetched "
'from the backend on the fly'), "from the backend on the fly"
'is_sqllab_view': _( ),
"Whether the table was generated by the 'Visualize' flow " "is_sqllab_view": _(
'in SQL Lab'), "Whether the table was generated by the 'Visualize' flow " "in SQL Lab"
'template_params': _( ),
'A set of parameters that become available in the query using ' "template_params": _(
'Jinja templating syntax'), "A set of parameters that become available in the query using "
'cache_timeout': _( "Jinja templating syntax"
'Duration (in seconds) of the caching timeout for this table. ' ),
'A timeout of 0 indicates that the cache never expires. ' "cache_timeout": _(
'Note this defaults to the database timeout if undefined.'), "Duration (in seconds) of the caching timeout for this table. "
"A timeout of 0 indicates that the cache never expires. "
"Note this defaults to the database timeout if undefined."
),
} }
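Several help strings above are written as adjacent string literals, and Black preserves that structure instead of merging the fragments; in a couple of spots (the 'Postgres, Redshift ' 'and DB2' schema text and the 'Visualize' flow text) both fragments even land on the same line. Python joins adjacent literals into a single string at compile time, so the rendered help text is unchanged. For instance:

    schema_help = (
        "Schema, as used only in some databases like Postgres, Redshift " "and DB2"
    )
    assert schema_help == (
        "Schema, as used only in some databases like Postgres, Redshift and DB2"
    )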
label_columns = { label_columns = {
'slices': _('Associated Charts'), "slices": _("Associated Charts"),
'link': _('Table'), "link": _("Table"),
'changed_by_': _('Changed By'), "changed_by_": _("Changed By"),
'database': _('Database'), "database": _("Database"),
'database_name': _('Database'), "database_name": _("Database"),
'changed_on_': _('Last Changed'), "changed_on_": _("Last Changed"),
'filter_select_enabled': _('Enable Filter Select'), "filter_select_enabled": _("Enable Filter Select"),
'schema': _('Schema'), "schema": _("Schema"),
'default_endpoint': _('Default Endpoint'), "default_endpoint": _("Default Endpoint"),
'offset': _('Offset'), "offset": _("Offset"),
'cache_timeout': _('Cache Timeout'), "cache_timeout": _("Cache Timeout"),
'table_name': _('Table Name'), "table_name": _("Table Name"),
'fetch_values_predicate': _('Fetch Values Predicate'), "fetch_values_predicate": _("Fetch Values Predicate"),
'owners': _('Owners'), "owners": _("Owners"),
'main_dttm_col': _('Main Datetime Column'), "main_dttm_col": _("Main Datetime Column"),
'description': _('Description'), "description": _("Description"),
'is_sqllab_view': _('SQL Lab View'), "is_sqllab_view": _("SQL Lab View"),
'template_params': _('Template parameters'), "template_params": _("Template parameters"),
'modified': _('Modified'), "modified": _("Modified"),
} }
edit_form_extra_fields = { edit_form_extra_fields = {
'database': QuerySelectField( "database": QuerySelectField(
'Database', "Database",
query_factory=lambda: db.session().query(models.Database), query_factory=lambda: db.session().query(models.Database),
widget=Select2Widget(extra_classes='readonly'), widget=Select2Widget(extra_classes="readonly"),
), )
} }
def pre_add(self, table): def pre_add(self, table):
@ -297,34 +352,43 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
table_query = db.session.query(models.SqlaTable).filter( table_query = db.session.query(models.SqlaTable).filter(
models.SqlaTable.table_name == table.table_name, models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema, models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id) models.SqlaTable.database_id == table.database.id,
)
if db.session.query(table_query.exists()).scalar(): if db.session.query(table_query.exists()).scalar():
raise Exception( raise Exception(get_datasource_exist_error_msg(table.full_name))
get_datasource_exist_error_msg(table.full_name))
# Fail before adding if the table can't be found # Fail before adding if the table can't be found
try: try:
table.get_sqla_table_object() table.get_sqla_table_object()
except Exception as e: except Exception as e:
logger.exception(f'Got an error in pre_add for {table.name}') logger.exception(f"Got an error in pre_add for {table.name}")
raise Exception(_( raise Exception(
'Table [{}] could not be found, ' _(
'please double check your ' "Table [{}] could not be found, "
'database connection, schema, and ' "please double check your "
'table name, error: {}').format(table.name, str(e))) "database connection, schema, and "
"table name, error: {}"
).format(table.name, str(e))
)
def post_add(self, table, flash_message=True): def post_add(self, table, flash_message=True):
table.fetch_metadata() table.fetch_metadata()
security_manager.add_permission_view_menu('datasource_access', table.get_perm()) security_manager.add_permission_view_menu("datasource_access", table.get_perm())
if table.schema: if table.schema:
security_manager.add_permission_view_menu('schema_access', table.schema_perm) security_manager.add_permission_view_menu(
"schema_access", table.schema_perm
)
if flash_message: if flash_message:
flash(_( flash(
'The table was created. ' _(
'As part of this two-phase configuration ' "The table was created. "
'process, you should now click the edit button by ' "As part of this two-phase configuration "
'the new table to configure it.'), 'info') "process, you should now click the edit button by "
"the new table to configure it."
),
"info",
)
def post_update(self, table): def post_update(self, table):
self.post_add(table, flash_message=False) self.post_add(table, flash_message=False)
@ -332,20 +396,18 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
def _delete(self, pk): def _delete(self, pk):
DeleteMixin._delete(self, pk) DeleteMixin._delete(self, pk)
@expose('/edit/<pk>', methods=['GET', 'POST']) @expose("/edit/<pk>", methods=["GET", "POST"])
@has_access @has_access
def edit(self, pk): def edit(self, pk):
"""Simple hack to redirect to explore view after saving""" """Simple hack to redirect to explore view after saving"""
resp = super(TableModelView, self).edit(pk) resp = super(TableModelView, self).edit(pk)
if isinstance(resp, str): if isinstance(resp, str):
return resp return resp
return redirect('/superset/explore/table/{}/'.format(pk)) return redirect("/superset/explore/table/{}/".format(pk))
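Note that Black confines itself to layout, quoting, and punctuation; it verifies that the reformatted file parses to an equivalent AST, so behavior-changing rewrites, such as turning the .format() call above into an f-string, are deliberately out of scope. The two forms are interchangeable here, which is easy to confirm:

    pk = 42
    old_style = "/superset/explore/table/{}/".format(pk)
    new_style = f"/superset/explore/table/{pk}/"
    assert old_style == new_style == "/superset/explore/table/42/"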
@action( @action(
'refresh', "refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
__('Refresh Metadata'), )
__('Refresh column metadata'),
'fa-refresh')
def refresh(self, tables): def refresh(self, tables):
if not isinstance(tables, list): if not isinstance(tables, list):
tables = [tables] tables = [tables]
@ -360,26 +422,29 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
if len(successes) > 0: if len(successes) > 0:
success_msg = _( success_msg = _(
'Metadata refreshed for the following table(s): %(tables)s', "Metadata refreshed for the following table(s): %(tables)s",
tables=', '.join([t.table_name for t in successes])) tables=", ".join([t.table_name for t in successes]),
flash(success_msg, 'info') )
flash(success_msg, "info")
if len(failures) > 0: if len(failures) > 0:
failure_msg = _( failure_msg = _(
'Unable to retrieve metadata for the following table(s): %(tables)s', "Unable to retrieve metadata for the following table(s): %(tables)s",
tables=', '.join([t.table_name for t in failures])) tables=", ".join([t.table_name for t in failures]),
flash(failure_msg, 'danger') )
flash(failure_msg, "danger")
return redirect('/tablemodelview/list/') return redirect("/tablemodelview/list/")
appbuilder.add_view_no_menu(TableModelView) appbuilder.add_view_no_menu(TableModelView)
appbuilder.add_link( appbuilder.add_link(
'Tables', "Tables",
label=__('Tables'), label=__("Tables"),
href='/tablemodelview/list/?_flt_1_is_sqllab_view=y', href="/tablemodelview/list/?_flt_1_is_sqllab_view=y",
icon='fa-table', icon="fa-table",
category='Sources', category="Sources",
category_label=__('Sources'), category_label=__("Sources"),
category_icon='fa-table') category_icon="fa-table",
)
appbuilder.add_separator('Sources') appbuilder.add_separator("Sources")
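A second recurring pattern in this file: collections that no longer fit on one line are exploded to one element per line, and Black appends a trailing comma after the last element, which keeps later additions to the list as one-line diffs. Collections that do fit, like list_columns above, are collapsed instead. In isolation:

    # Collapsed, because the whole literal fits within 88 columns.
    list_columns = ["link", "database_name", "changed_by_", "modified"]

    # Exploded, one element per line, trailing comma added by Black
    # (abbreviated from the edit_columns list above).
    edit_columns = [
        "table_name",
        "sql",
        "filter_select_enabled",
        "fetch_values_predicate",
    ]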

View File

@ -28,6 +28,6 @@ from .multiformat_time_series import load_multiformat_time_series # noqa
from .paris import load_paris_iris_geojson # noqa from .paris import load_paris_iris_geojson # noqa
from .random_time_series import load_random_time_series_data # noqa from .random_time_series import load_random_time_series_data # noqa
from .sf_population_polygons import load_sf_population_polygons # noqa from .sf_population_polygons import load_sf_population_polygons # noqa
from .tabbed_dashboard import load_tabbed_dashboard # noqa from .tabbed_dashboard import load_tabbed_dashboard # noqa
from .unicode_test_data import load_unicode_test_data # noqa from .unicode_test_data import load_unicode_test_data # noqa
from .world_bank import load_world_bank_health_n_pop # noqa from .world_bank import load_world_bank_health_n_pop # noqa

View File

@ -26,30 +26,31 @@ from .helpers import TBL, get_example_data
def load_bart_lines(): def load_bart_lines():
tbl_name = 'bart_lines' tbl_name = "bart_lines"
content = get_example_data('bart-lines.json.gz') content = get_example_data("bart-lines.json.gz")
df = pd.read_json(content, encoding='latin-1') df = pd.read_json(content, encoding="latin-1")
df['path_json'] = df.path.map(json.dumps) df["path_json"] = df.path.map(json.dumps)
df['polyline'] = df.path.map(polyline.encode) df["polyline"] = df.path.map(polyline.encode)
del df['path'] del df["path"]
df.to_sql( df.to_sql(
tbl_name, tbl_name,
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'color': String(255), "color": String(255),
'name': String(255), "name": String(255),
'polyline': Text, "polyline": Text,
'path_json': Text, "path_json": Text,
}, },
index=False) index=False,
print('Creating table {} reference'.format(tbl_name)) )
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl: if not tbl:
tbl = TBL(table_name=tbl_name) tbl = TBL(table_name=tbl_name)
tbl.description = 'BART lines' tbl.description = "BART lines"
tbl.database = get_or_create_main_db() tbl.database = get_or_create_main_db()
db.session.merge(tbl) db.session.merge(tbl)
db.session.commit() db.session.commit()

View File

@ -38,46 +38,46 @@ from .helpers import (
def load_birth_names(): def load_birth_names():
"""Loading birth name dataset from a zip file in the repo""" """Loading birth name dataset from a zip file in the repo"""
data = get_example_data('birth_names.json.gz') data = get_example_data("birth_names.json.gz")
pdf = pd.read_json(data) pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit='ms') pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf.to_sql( pdf.to_sql(
'birth_names', "birth_names",
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'ds': DateTime, "ds": DateTime,
'gender': String(16), "gender": String(16),
'state': String(10), "state": String(10),
'name': String(255), "name": String(255),
}, },
index=False) index=False,
print('Done loading table!') )
print('-' * 80) print("Done loading table!")
print("-" * 80)
print('Creating table [birth_names] reference') print("Creating table [birth_names] reference")
obj = db.session.query(TBL).filter_by(table_name='birth_names').first() obj = db.session.query(TBL).filter_by(table_name="birth_names").first()
if not obj: if not obj:
obj = TBL(table_name='birth_names') obj = TBL(table_name="birth_names")
obj.main_dttm_col = 'ds' obj.main_dttm_col = "ds"
obj.database = get_or_create_main_db() obj.database = get_or_create_main_db()
obj.filter_select_enabled = True obj.filter_select_enabled = True
if not any(col.column_name == 'num_california' for col in obj.columns): if not any(col.column_name == "num_california" for col in obj.columns):
col_state = str(column('state').compile(db.engine)) col_state = str(column("state").compile(db.engine))
col_num = str(column('num').compile(db.engine)) col_num = str(column("num").compile(db.engine))
obj.columns.append(TableColumn( obj.columns.append(
column_name='num_california', TableColumn(
expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END", column_name="num_california",
)) expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
)
)
if not any(col.metric_name == 'sum__num' for col in obj.metrics): if not any(col.metric_name == "sum__num" for col in obj.metrics):
col = str(column('num').compile(db.engine)) col = str(column("num").compile(db.engine))
obj.metrics.append(SqlMetric( obj.metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))
metric_name='sum__num',
expression=f'SUM({col})',
))
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
@ -85,149 +85,149 @@ def load_birth_names():
tbl = obj tbl = obj
defaults = { defaults = {
'compare_lag': '10', "compare_lag": "10",
'compare_suffix': 'o10Y', "compare_suffix": "o10Y",
'limit': '25', "limit": "25",
'granularity_sqla': 'ds', "granularity_sqla": "ds",
'groupby': [], "groupby": [],
'metric': 'sum__num', "metric": "sum__num",
'metrics': ['sum__num'], "metrics": ["sum__num"],
'row_limit': config.get('ROW_LIMIT'), "row_limit": config.get("ROW_LIMIT"),
'since': '100 years ago', "since": "100 years ago",
'until': 'now', "until": "now",
'viz_type': 'table', "viz_type": "table",
'where': '', "where": "",
'markup_type': 'markdown', "markup_type": "markdown",
} }
admin = security_manager.find_user('admin') admin = security_manager.find_user("admin")
print('Creating some slices') print("Creating some slices")
slices = [ slices = [
Slice( Slice(
slice_name='Girls', slice_name="Girls",
viz_type='table', viz_type="table",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
groupby=['name'], groupby=["name"],
filters=[{ filters=[{"col": "gender", "op": "in", "val": ["girl"]}],
'col': 'gender',
'op': 'in',
'val': ['girl'],
}],
row_limit=50, row_limit=50,
timeseries_limit_metric='sum__num')), timeseries_limit_metric="sum__num",
),
),
Slice( Slice(
slice_name='Boys', slice_name="Boys",
viz_type='table', viz_type="table",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
groupby=['name'], groupby=["name"],
filters=[{ filters=[{"col": "gender", "op": "in", "val": ["boy"]}],
'col': 'gender', row_limit=50,
'op': 'in', ),
'val': ['boy'], ),
}],
row_limit=50)),
Slice( Slice(
slice_name='Participants', slice_name="Participants",
viz_type='big_number', viz_type="big_number",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='big_number', granularity_sqla='ds', viz_type="big_number",
compare_lag='5', compare_suffix='over 5Y')), granularity_sqla="ds",
compare_lag="5",
compare_suffix="over 5Y",
),
),
Slice( Slice(
slice_name='Genders', slice_name="Genders",
viz_type='pie', viz_type="pie",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(defaults, viz_type="pie", groupby=["gender"]),
defaults, ),
viz_type='pie', groupby=['gender'])),
Slice( Slice(
slice_name='Genders by State', slice_name="Genders by State",
viz_type='dist_bar', viz_type="dist_bar",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
adhoc_filters=[ adhoc_filters=[
{ {
'clause': 'WHERE', "clause": "WHERE",
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'filterOptionName': '2745eae5', "filterOptionName": "2745eae5",
'comparator': ['other'], "comparator": ["other"],
'operator': 'not in', "operator": "not in",
'subject': 'state', "subject": "state",
}, }
], ],
viz_type='dist_bar', viz_type="dist_bar",
metrics=[ metrics=[
{ {
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {"column_name": "sum_boys", "type": "BIGINT(20)"},
'column_name': 'sum_boys', "aggregate": "SUM",
'type': 'BIGINT(20)', "label": "Boys",
}, "optionName": "metric_11",
'aggregate': 'SUM',
'label': 'Boys',
'optionName': 'metric_11',
}, },
{ {
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {"column_name": "sum_girls", "type": "BIGINT(20)"},
'column_name': 'sum_girls', "aggregate": "SUM",
'type': 'BIGINT(20)', "label": "Girls",
}, "optionName": "metric_12",
'aggregate': 'SUM',
'label': 'Girls',
'optionName': 'metric_12',
}, },
], ],
groupby=['state'])), groupby=["state"],
),
),
Slice( Slice(
slice_name='Trends', slice_name="Trends",
viz_type='line', viz_type="line",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='line', groupby=['name'], viz_type="line",
granularity_sqla='ds', rich_tooltip=True, show_legend=True)), groupby=["name"],
granularity_sqla="ds",
rich_tooltip=True,
show_legend=True,
),
),
Slice( Slice(
slice_name='Average and Sum Trends', slice_name="Average and Sum Trends",
viz_type='dual_line', viz_type="dual_line",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='dual_line', viz_type="dual_line",
metric={ metric={
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {"column_name": "num", "type": "BIGINT(20)"},
'column_name': 'num', "aggregate": "AVG",
'type': 'BIGINT(20)', "label": "AVG(num)",
}, "optionName": "metric_vgops097wej_g8uff99zhk7",
'aggregate': 'AVG',
'label': 'AVG(num)',
'optionName': 'metric_vgops097wej_g8uff99zhk7',
}, },
metric_2='sum__num', metric_2="sum__num",
granularity_sqla='ds')), granularity_sqla="ds",
),
),
Slice( Slice(
slice_name='Title', slice_name="Title",
viz_type='markup', viz_type="markup",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='markup', markup_type='html', viz_type="markup",
markup_type="html",
code="""\ code="""\
<div style='text-align:center'> <div style='text-align:center'>
<h1>Birth Names Dashboard</h1> <h1>Birth Names Dashboard</h1>
@ -237,135 +237,156 @@ def load_birth_names():
</p> </p>
<img src='/static/assets/images/babytux.jpg'> <img src='/static/assets/images/babytux.jpg'>
</div> </div>
""")), """,
),
),
Slice( Slice(
slice_name='Name Cloud', slice_name="Name Cloud",
viz_type='word_cloud', viz_type="word_cloud",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='word_cloud', size_from='10', viz_type="word_cloud",
series='name', size_to='70', rotation='square', size_from="10",
limit='100')), series="name",
size_to="70",
rotation="square",
limit="100",
),
),
Slice( Slice(
slice_name='Pivot Table', slice_name="Pivot Table",
viz_type='pivot_table', viz_type="pivot_table",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='pivot_table', metrics=['sum__num'], viz_type="pivot_table",
groupby=['name'], columns=['state'])), metrics=["sum__num"],
groupby=["name"],
columns=["state"],
),
),
Slice( Slice(
slice_name='Number of Girls', slice_name="Number of Girls",
viz_type='big_number_total', viz_type="big_number_total",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='big_number_total', granularity_sqla='ds', viz_type="big_number_total",
filters=[{ granularity_sqla="ds",
'col': 'gender', filters=[{"col": "gender", "op": "in", "val": ["girl"]}],
'op': 'in', subheader="total female participants",
'val': ['girl'], ),
}], ),
subheader='total female participants')),
Slice( Slice(
slice_name='Number of California Births', slice_name="Number of California Births",
viz_type='big_number_total', viz_type="big_number_total",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
metric={ metric={
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {
'column_name': 'num_california', "column_name": "num_california",
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
}, },
'aggregate': 'SUM', "aggregate": "SUM",
'label': 'SUM(num_california)', "label": "SUM(num_california)",
}, },
viz_type='big_number_total', viz_type="big_number_total",
granularity_sqla='ds')), granularity_sqla="ds",
),
),
Slice( Slice(
slice_name='Top 10 California Names Timeseries', slice_name="Top 10 California Names Timeseries",
viz_type='line', viz_type="line",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
metrics=[{ metrics=[
'expressionType': 'SIMPLE', {
'column': { "expressionType": "SIMPLE",
'column_name': 'num_california', "column": {
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END", "column_name": "num_california",
}, "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
'aggregate': 'SUM', },
'label': 'SUM(num_california)', "aggregate": "SUM",
}], "label": "SUM(num_california)",
viz_type='line', }
granularity_sqla='ds', ],
groupby=['name'], viz_type="line",
granularity_sqla="ds",
groupby=["name"],
timeseries_limit_metric={ timeseries_limit_metric={
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {
'column_name': 'num_california', "column_name": "num_california",
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
}, },
'aggregate': 'SUM', "aggregate": "SUM",
'label': 'SUM(num_california)', "label": "SUM(num_california)",
}, },
limit='10')), limit="10",
),
),
Slice( Slice(
slice_name='Names Sorted by Num in California', slice_name="Names Sorted by Num in California",
viz_type='table', viz_type="table",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
groupby=['name'], groupby=["name"],
row_limit=50, row_limit=50,
timeseries_limit_metric={ timeseries_limit_metric={
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {
'column_name': 'num_california', "column_name": "num_california",
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
}, },
'aggregate': 'SUM', "aggregate": "SUM",
'label': 'SUM(num_california)', "label": "SUM(num_california)",
})), },
),
),
Slice( Slice(
slice_name='Num Births Trend', slice_name="Num Births Trend",
viz_type='line', viz_type="line",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(defaults, viz_type="line"),
defaults, ),
viz_type='line')),
Slice( Slice(
slice_name='Daily Totals', slice_name="Daily Totals",
viz_type='table', viz_type="table",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
created_by=admin, created_by=admin,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
groupby=['ds'], groupby=["ds"],
since='40 years ago', since="40 years ago",
until='now', until="now",
viz_type='table')), viz_type="table",
),
),
] ]
for slc in slices: for slc in slices:
merge_slice(slc) merge_slice(slc)
print('Creating a dashboard') print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(dashboard_title='Births').first() dash = db.session.query(Dash).filter_by(dashboard_title="Births").first()
if not dash: if not dash:
dash = Dash() dash = Dash()
js = textwrap.dedent("""\ js = textwrap.dedent(
# pylint: disable=line-too-long
"""\
{ {
"CHART-0dd270f0": { "CHART-0dd270f0": {
"meta": { "meta": {
@ -614,13 +635,15 @@ def load_birth_names():
}, },
"DASHBOARD_VERSION_KEY": "v2" "DASHBOARD_VERSION_KEY": "v2"
} }
""") """
# pylint: enable=line-too-long
)
pos = json.loads(js) pos = json.loads(js)
# dashboard v2 doesn't allow add markup slice # dashboard v2 doesn't allow add markup slice
dash.slices = [slc for slc in slices if slc.viz_type != 'markup'] dash.slices = [slc for slc in slices if slc.viz_type != "markup"]
update_slice_ids(pos, dash.slices) update_slice_ids(pos, dash.slices)
dash.dashboard_title = 'Births' dash.dashboard_title = "Births"
dash.position_json = json.dumps(pos, indent=4) dash.position_json = json.dumps(pos, indent=4)
dash.slug = 'births' dash.slug = "births"
db.session.merge(dash) db.session.merge(dash)
db.session.commit() db.session.commit()
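Black never reflows the contents of string literals, so the machine-generated dashboard position JSON above stays as a handful of very long lines; the hunk compensates by wrapping the textwrap.dedent() call in pylint block pragmas, which silence line-too-long only between the disable and enable comments. A reduced sketch of the same pattern (the JSON body here is a stand-in for the real position_json):

    import json
    import textwrap

    # pylint: disable=line-too-long
    js = textwrap.dedent(
        """\
        {"DASHBOARD_VERSION_KEY": "v2", "note": "stand-in for the very long generated position_json lines"}
        """
    )
    # pylint: enable=line-too-long

    pos = json.loads(js)
    assert pos["DASHBOARD_VERSION_KEY"] == "v2"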

File diff suppressed because it is too large

View File

@ -36,75 +36,71 @@ from .helpers import (
def load_country_map_data(): def load_country_map_data():
"""Loading data for map with country map""" """Loading data for map with country map"""
csv_bytes = get_example_data( csv_bytes = get_example_data(
'birth_france_data_for_country_map.csv', is_gzip=False, make_bytes=True) "birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True
data = pd.read_csv(csv_bytes, encoding='utf-8') )
data['dttm'] = datetime.datetime.now().date() data = pd.read_csv(csv_bytes, encoding="utf-8")
data["dttm"] = datetime.datetime.now().date()
data.to_sql( # pylint: disable=no-member data.to_sql( # pylint: disable=no-member
'birth_france_by_region', "birth_france_by_region",
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'DEPT_ID': String(10), "DEPT_ID": String(10),
'2003': BigInteger, "2003": BigInteger,
'2004': BigInteger, "2004": BigInteger,
'2005': BigInteger, "2005": BigInteger,
'2006': BigInteger, "2006": BigInteger,
'2007': BigInteger, "2007": BigInteger,
'2008': BigInteger, "2008": BigInteger,
'2009': BigInteger, "2009": BigInteger,
'2010': BigInteger, "2010": BigInteger,
'2011': BigInteger, "2011": BigInteger,
'2012': BigInteger, "2012": BigInteger,
'2013': BigInteger, "2013": BigInteger,
'2014': BigInteger, "2014": BigInteger,
'dttm': Date(), "dttm": Date(),
}, },
index=False) index=False,
print('Done loading table!') )
print('-' * 80) print("Done loading table!")
print('Creating table reference') print("-" * 80)
obj = db.session.query(TBL).filter_by(table_name='birth_france_by_region').first() print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name="birth_france_by_region").first()
if not obj: if not obj:
obj = TBL(table_name='birth_france_by_region') obj = TBL(table_name="birth_france_by_region")
obj.main_dttm_col = 'dttm' obj.main_dttm_col = "dttm"
obj.database = utils.get_or_create_main_db() obj.database = utils.get_or_create_main_db()
if not any(col.metric_name == 'avg__2004' for col in obj.metrics): if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column('2004').compile(db.engine)) col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric( obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
metric_name='avg__2004',
expression=f'AVG({col})',
))
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
obj.fetch_metadata() obj.fetch_metadata()
tbl = obj tbl = obj
slice_data = { slice_data = {
'granularity_sqla': '', "granularity_sqla": "",
'since': '', "since": "",
'until': '', "until": "",
'where': '', "where": "",
'viz_type': 'country_map', "viz_type": "country_map",
'entity': 'DEPT_ID', "entity": "DEPT_ID",
'metric': { "metric": {
'expressionType': 'SIMPLE', "expressionType": "SIMPLE",
'column': { "column": {"type": "INT", "column_name": "2004"},
'type': 'INT', "aggregate": "AVG",
'column_name': '2004', "label": "Boys",
}, "optionName": "metric_112342",
'aggregate': 'AVG',
'label': 'Boys',
'optionName': 'metric_112342',
}, },
'row_limit': 500000, "row_limit": 500000,
} }
print('Creating a slice') print("Creating a slice")
slc = Slice( slc = Slice(
slice_name='Birth in France by department in 2016', slice_name="Birth in France by department in 2016",
viz_type='country_map', viz_type="country_map",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )

View File

@ -22,12 +22,13 @@ from superset.models.core import CssTemplate
def load_css_templates(): def load_css_templates():
"""Loads 2 css templates to demonstrate the feature""" """Loads 2 css templates to demonstrate the feature"""
print('Creating default CSS templates') print("Creating default CSS templates")
obj = db.session.query(CssTemplate).filter_by(template_name='Flat').first() obj = db.session.query(CssTemplate).filter_by(template_name="Flat").first()
if not obj: if not obj:
obj = CssTemplate(template_name='Flat') obj = CssTemplate(template_name="Flat")
css = textwrap.dedent("""\ css = textwrap.dedent(
"""\
.gridster div.widget { .gridster div.widget {
transition: background-color 0.5s ease; transition: background-color 0.5s ease;
background-color: #FAFAFA; background-color: #FAFAFA;
@ -58,16 +59,17 @@ def load_css_templates():
'#ff3339', '#ff1ab1', '#005c66', '#00b3a5', '#55d12e', '#b37e00', '#988b4e', '#ff3339', '#ff1ab1', '#005c66', '#00b3a5', '#55d12e', '#b37e00', '#988b4e',
]; ];
*/ */
""") """
)
obj.css = css obj.css = css
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
obj = ( obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
db.session.query(CssTemplate).filter_by(template_name='Courier Black').first())
if not obj: if not obj:
obj = CssTemplate(template_name='Courier Black') obj = CssTemplate(template_name="Courier Black")
css = textwrap.dedent("""\ css = textwrap.dedent(
"""\
.gridster div.widget { .gridster div.widget {
transition: background-color 0.5s ease; transition: background-color 0.5s ease;
background-color: #EEE; background-color: #EEE;
@ -113,7 +115,8 @@ def load_css_templates():
'#ff3339', '#ff1ab1', '#005c66', '#00b3a5', '#55d12e', '#b37e00', '#988b4e', '#ff3339', '#ff1ab1', '#005c66', '#00b3a5', '#55d12e', '#b37e00', '#988b4e',
]; ];
*/ */
""") """
)
obj.css = css obj.css = css
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
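In these hunks Black moves each triple-quoted CSS template onto its own lines inside the textwrap.dedent() parentheses. The semantics are unchanged: the backslash after the opening quotes suppresses a leading blank line, and dedent() strips the shared indentation at runtime, so the stored template starts directly with the first selector. A minimal sketch:

    import textwrap

    css = textwrap.dedent(
        """\
        .gridster div.widget {
            background-color: #FAFAFA;
        }
        """
    )
    assert css.startswith(".gridster div.widget {")
    assert css.endswith("}\n")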

View File

@ -18,21 +18,9 @@
import json import json
from superset import db from superset import db
from .helpers import ( from .helpers import Dash, get_slice_json, merge_slice, Slice, TBL, update_slice_ids
Dash,
get_slice_json,
merge_slice,
Slice,
TBL,
update_slice_ids,
)
COLOR_RED = { COLOR_RED = {"r": 205, "g": 0, "b": 3, "a": 0.82}
'r': 205,
'g': 0,
'b': 3,
'a': 0.82,
}
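The collapse of COLOR_RED (and of the .helpers import just above it) is the flip side of the explosion rule: with the Black release pinned for this commit, a collection that fits in 88 columns is joined onto one line even when the author had left a trailing comma. Later Black releases treat such a pre-existing trailing comma as a "magic trailing comma" and keep the collection exploded, so re-running a newer Black could undo some of these collapses; that is a version detail, not something this commit relies on. Before and after, as in the hunk:

    # As written before this commit (trailing comma after the last entry):
    COLOR_RED = {
        'r': 205,
        'g': 0,
        'b': 3,
        'a': 0.82,
    }

    # As emitted by the Black release used here (fits in 88 columns):
    COLOR_RED = {"r": 205, "g": 0, "b": 3, "a": 0.82}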
POSITION_JSON = """\ POSITION_JSON = """\
{ {
"CHART-3afd9d70": { "CHART-3afd9d70": {
@ -177,46 +165,42 @@ POSITION_JSON = """\
def load_deck_dash(): def load_deck_dash():
print('Loading deck.gl dashboard') print("Loading deck.gl dashboard")
slices = [] slices = []
tbl = db.session.query(TBL).filter_by(table_name='long_lat').first() tbl = db.session.query(TBL).filter_by(table_name="long_lat").first()
slice_data = { slice_data = {
'spatial': { "spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
'type': 'latlong', "color_picker": COLOR_RED,
'lonCol': 'LON', "datasource": "5__table",
'latCol': 'LAT', "filters": [],
"granularity_sqla": None,
"groupby": [],
"having": "",
"mapbox_style": "mapbox://styles/mapbox/light-v9",
"multiplier": 10,
"point_radius_fixed": {"type": "metric", "value": "count"},
"point_unit": "square_m",
"min_radius": 1,
"row_limit": 5000,
"time_range": " : ",
"size": "count",
"time_grain_sqla": None,
"viewport": {
"bearing": -4.952916738791771,
"latitude": 37.78926922909199,
"longitude": -122.42613341901688,
"pitch": 4.750411100577438,
"zoom": 12.729132798697304,
}, },
'color_picker': COLOR_RED, "viz_type": "deck_scatter",
'datasource': '5__table', "where": "",
'filters': [],
'granularity_sqla': None,
'groupby': [],
'having': '',
'mapbox_style': 'mapbox://styles/mapbox/light-v9',
'multiplier': 10,
'point_radius_fixed': {'type': 'metric', 'value': 'count'},
'point_unit': 'square_m',
'min_radius': 1,
'row_limit': 5000,
'time_range': ' : ',
'size': 'count',
'time_grain_sqla': None,
'viewport': {
'bearing': -4.952916738791771,
'latitude': 37.78926922909199,
'longitude': -122.42613341901688,
'pitch': 4.750411100577438,
'zoom': 12.729132798697304,
},
'viz_type': 'deck_scatter',
'where': '',
} }
print('Creating Scatterplot slice') print("Creating Scatterplot slice")
slc = Slice( slc = Slice(
slice_name='Scatterplot', slice_name="Scatterplot",
viz_type='deck_scatter', viz_type="deck_scatter",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -224,46 +208,37 @@ def load_deck_dash():
slices.append(slc) slices.append(slc)
slice_data = { slice_data = {
'point_unit': 'square_m', "point_unit": "square_m",
'filters': [], "filters": [],
'row_limit': 5000, "row_limit": 5000,
'spatial': { "spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
'type': 'latlong', "mapbox_style": "mapbox://styles/mapbox/dark-v9",
'lonCol': 'LON', "granularity_sqla": None,
'latCol': 'LAT', "size": "count",
"viz_type": "deck_screengrid",
"time_range": "No filter",
"point_radius": "Auto",
"color_picker": {"a": 1, "r": 14, "b": 0, "g": 255},
"grid_size": 20,
"where": "",
"having": "",
"viewport": {
"zoom": 14.161641703941438,
"longitude": -122.41827069521386,
"bearing": -4.952916738791771,
"latitude": 37.76024135844065,
"pitch": 4.750411100577438,
}, },
'mapbox_style': 'mapbox://styles/mapbox/dark-v9', "point_radius_fixed": {"type": "fix", "value": 2000},
'granularity_sqla': None, "datasource": "5__table",
'size': 'count', "time_grain_sqla": None,
'viz_type': 'deck_screengrid', "groupby": [],
'time_range': 'No filter',
'point_radius': 'Auto',
'color_picker': {
'a': 1,
'r': 14,
'b': 0,
'g': 255,
},
'grid_size': 20,
'where': '',
'having': '',
'viewport': {
'zoom': 14.161641703941438,
'longitude': -122.41827069521386,
'bearing': -4.952916738791771,
'latitude': 37.76024135844065,
'pitch': 4.750411100577438,
},
'point_radius_fixed': {'type': 'fix', 'value': 2000},
'datasource': '5__table',
'time_grain_sqla': None,
'groupby': [],
} }
print('Creating Screen Grid slice') print("Creating Screen Grid slice")
slc = Slice( slc = Slice(
slice_name='Screen grid', slice_name="Screen grid",
viz_type='deck_screengrid', viz_type="deck_screengrid",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -271,47 +246,38 @@ def load_deck_dash():
slices.append(slc) slices.append(slc)
slice_data = { slice_data = {
'spatial': { "spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
'type': 'latlong', "filters": [],
'lonCol': 'LON', "row_limit": 5000,
'latCol': 'LAT', "mapbox_style": "mapbox://styles/mapbox/streets-v9",
"granularity_sqla": None,
"size": "count",
"viz_type": "deck_hex",
"time_range": "No filter",
"point_radius_unit": "Pixels",
"point_radius": "Auto",
"color_picker": {"a": 1, "r": 14, "b": 0, "g": 255},
"grid_size": 40,
"extruded": True,
"having": "",
"viewport": {
"latitude": 37.789795085160335,
"pitch": 54.08961642447763,
"zoom": 13.835465702403654,
"longitude": -122.40632230075536,
"bearing": -2.3984797349335167,
}, },
'filters': [], "where": "",
'row_limit': 5000, "point_radius_fixed": {"type": "fix", "value": 2000},
'mapbox_style': 'mapbox://styles/mapbox/streets-v9', "datasource": "5__table",
'granularity_sqla': None, "time_grain_sqla": None,
'size': 'count', "groupby": [],
'viz_type': 'deck_hex',
'time_range': 'No filter',
'point_radius_unit': 'Pixels',
'point_radius': 'Auto',
'color_picker': {
'a': 1,
'r': 14,
'b': 0,
'g': 255,
},
'grid_size': 40,
'extruded': True,
'having': '',
'viewport': {
'latitude': 37.789795085160335,
'pitch': 54.08961642447763,
'zoom': 13.835465702403654,
'longitude': -122.40632230075536,
'bearing': -2.3984797349335167,
},
'where': '',
'point_radius_fixed': {'type': 'fix', 'value': 2000},
'datasource': '5__table',
'time_grain_sqla': None,
'groupby': [],
} }
print('Creating Hex slice') print("Creating Hex slice")
slc = Slice( slc = Slice(
slice_name='Hexagons', slice_name="Hexagons",
viz_type='deck_hex', viz_type="deck_hex",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -319,120 +285,98 @@ def load_deck_dash():
slices.append(slc) slices.append(slc)
slice_data = { slice_data = {
'spatial': { "spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
'type': 'latlong', "filters": [],
'lonCol': 'LON', "row_limit": 5000,
'latCol': 'LAT', "mapbox_style": "mapbox://styles/mapbox/satellite-streets-v9",
"granularity_sqla": None,
"size": "count",
"viz_type": "deck_grid",
"point_radius_unit": "Pixels",
"point_radius": "Auto",
"time_range": "No filter",
"color_picker": {"a": 1, "r": 14, "b": 0, "g": 255},
"grid_size": 120,
"extruded": True,
"having": "",
"viewport": {
"longitude": -122.42066918995666,
"bearing": 155.80099696026355,
"zoom": 12.699690845482069,
"latitude": 37.7942314882596,
"pitch": 53.470800300695146,
}, },
'filters': [], "where": "",
'row_limit': 5000, "point_radius_fixed": {"type": "fix", "value": 2000},
'mapbox_style': 'mapbox://styles/mapbox/satellite-streets-v9', "datasource": "5__table",
'granularity_sqla': None, "time_grain_sqla": None,
'size': 'count', "groupby": [],
'viz_type': 'deck_grid',
'point_radius_unit': 'Pixels',
'point_radius': 'Auto',
'time_range': 'No filter',
'color_picker': {
'a': 1,
'r': 14,
'b': 0,
'g': 255,
},
'grid_size': 120,
'extruded': True,
'having': '',
'viewport': {
'longitude': -122.42066918995666,
'bearing': 155.80099696026355,
'zoom': 12.699690845482069,
'latitude': 37.7942314882596,
'pitch': 53.470800300695146,
},
'where': '',
'point_radius_fixed': {'type': 'fix', 'value': 2000},
'datasource': '5__table',
'time_grain_sqla': None,
'groupby': [],
} }
print('Creating Grid slice') print("Creating Grid slice")
slc = Slice( slc = Slice(
slice_name='Grid', slice_name="Grid",
viz_type='deck_grid', viz_type="deck_grid",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
slices.append(slc) slices.append(slc)
polygon_tbl = db.session.query(TBL) \ polygon_tbl = (
.filter_by(table_name='sf_population_polygons').first() db.session.query(TBL).filter_by(table_name="sf_population_polygons").first()
)
slice_data = { slice_data = {
'datasource': '11__table', "datasource": "11__table",
'viz_type': 'deck_polygon', "viz_type": "deck_polygon",
'slice_id': 41, "slice_id": 41,
'granularity_sqla': None, "granularity_sqla": None,
'time_grain_sqla': None, "time_grain_sqla": None,
'time_range': ' : ', "time_range": " : ",
'line_column': 'contour', "line_column": "contour",
'metric': None, "metric": None,
'line_type': 'json', "line_type": "json",
'mapbox_style': 'mapbox://styles/mapbox/light-v9', "mapbox_style": "mapbox://styles/mapbox/light-v9",
'viewport': { "viewport": {
'longitude': -122.43388541747726, "longitude": -122.43388541747726,
'latitude': 37.752020331384834, "latitude": 37.752020331384834,
'zoom': 11.133995608594631, "zoom": 11.133995608594631,
'bearing': 37.89506450385642, "bearing": 37.89506450385642,
'pitch': 60, "pitch": 60,
'width': 667, "width": 667,
'height': 906, "height": 906,
'altitude': 1.5, "altitude": 1.5,
'maxZoom': 20, "maxZoom": 20,
'minZoom': 0, "minZoom": 0,
'maxPitch': 60, "maxPitch": 60,
'minPitch': 0, "minPitch": 0,
'maxLatitude': 85.05113, "maxLatitude": 85.05113,
'minLatitude': -85.05113, "minLatitude": -85.05113,
}, },
'reverse_long_lat': False, "reverse_long_lat": False,
'fill_color_picker': { "fill_color_picker": {"r": 3, "g": 65, "b": 73, "a": 1},
'r': 3, "stroke_color_picker": {"r": 0, "g": 122, "b": 135, "a": 1},
'g': 65, "filled": True,
'b': 73, "stroked": False,
'a': 1, "extruded": True,
}, "point_radius_scale": 100,
'stroke_color_picker': { "js_columns": ["population", "area"],
'r': 0, "js_data_mutator": "data => data.map(d => ({\n"
'g': 122, " ...d,\n"
'b': 135, " elevation: d.extraProps.population/d.extraProps.area/10,\n"
'a': 1, "}));",
}, "js_tooltip": "",
'filled': True, "js_onclick_href": "",
'stroked': False, "where": "",
'extruded': True, "having": "",
'point_radius_scale': 100, "filters": [],
'js_columns': [
'population',
'area',
],
'js_data_mutator':
'data => data.map(d => ({\n'
' ...d,\n'
' elevation: d.extraProps.population/d.extraProps.area/10,\n'
'}));',
'js_tooltip': '',
'js_onclick_href': '',
'where': '',
'having': '',
'filters': [],
} }
print('Creating Polygon slice') print("Creating Polygon slice")
slc = Slice( slc = Slice(
slice_name='Polygons', slice_name="Polygons",
viz_type='deck_polygon', viz_type="deck_polygon",
datasource_type='table', datasource_type="table",
datasource_id=polygon_tbl.id, datasource_id=polygon_tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
@ -440,125 +384,116 @@ def load_deck_dash():
slices.append(slc) slices.append(slc)
slice_data = { slice_data = {
'datasource': '10__table', "datasource": "10__table",
'viz_type': 'deck_arc', "viz_type": "deck_arc",
'slice_id': 42, "slice_id": 42,
'granularity_sqla': None, "granularity_sqla": None,
'time_grain_sqla': None, "time_grain_sqla": None,
'time_range': ' : ', "time_range": " : ",
'start_spatial': { "start_spatial": {
'type': 'latlong', "type": "latlong",
'latCol': 'LATITUDE', "latCol": "LATITUDE",
'lonCol': 'LONGITUDE', "lonCol": "LONGITUDE",
}, },
'end_spatial': { "end_spatial": {
'type': 'latlong', "type": "latlong",
'latCol': 'LATITUDE_DEST', "latCol": "LATITUDE_DEST",
'lonCol': 'LONGITUDE_DEST', "lonCol": "LONGITUDE_DEST",
}, },
'row_limit': 5000, "row_limit": 5000,
'mapbox_style': 'mapbox://styles/mapbox/light-v9', "mapbox_style": "mapbox://styles/mapbox/light-v9",
'viewport': { "viewport": {
'altitude': 1.5, "altitude": 1.5,
'bearing': 8.546256357301871, "bearing": 8.546256357301871,
'height': 642, "height": 642,
'latitude': 44.596651438714254, "latitude": 44.596651438714254,
'longitude': -91.84340711201104, "longitude": -91.84340711201104,
'maxLatitude': 85.05113, "maxLatitude": 85.05113,
'maxPitch': 60, "maxPitch": 60,
'maxZoom': 20, "maxZoom": 20,
'minLatitude': -85.05113, "minLatitude": -85.05113,
'minPitch': 0, "minPitch": 0,
'minZoom': 0, "minZoom": 0,
'pitch': 60, "pitch": 60,
'width': 997, "width": 997,
'zoom': 2.929837070560775, "zoom": 2.929837070560775,
}, },
'color_picker': { "color_picker": {"r": 0, "g": 122, "b": 135, "a": 1},
'r': 0, "stroke_width": 1,
'g': 122, "where": "",
'b': 135, "having": "",
'a': 1, "filters": [],
},
'stroke_width': 1,
'where': '',
'having': '',
'filters': [],
} }
print('Creating Arc slice') print("Creating Arc slice")
slc = Slice( slc = Slice(
slice_name='Arcs', slice_name="Arcs",
viz_type='deck_arc', viz_type="deck_arc",
datasource_type='table', datasource_type="table",
datasource_id=db.session.query(TBL).filter_by(table_name='flights').first().id, datasource_id=db.session.query(TBL).filter_by(table_name="flights").first().id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
slices.append(slc) slices.append(slc)
slice_data = { slice_data = {
'datasource': '12__table', "datasource": "12__table",
'slice_id': 43, "slice_id": 43,
'viz_type': 'deck_path', "viz_type": "deck_path",
'time_grain_sqla': None, "time_grain_sqla": None,
'time_range': ' : ', "time_range": " : ",
'line_column': 'path_json', "line_column": "path_json",
'line_type': 'json', "line_type": "json",
'row_limit': 5000, "row_limit": 5000,
'mapbox_style': 'mapbox://styles/mapbox/light-v9', "mapbox_style": "mapbox://styles/mapbox/light-v9",
'viewport': { "viewport": {
'longitude': -122.18885402582598, "longitude": -122.18885402582598,
'latitude': 37.73671752604488, "latitude": 37.73671752604488,
'zoom': 9.51847667620428, "zoom": 9.51847667620428,
'bearing': 0, "bearing": 0,
'pitch': 0, "pitch": 0,
'width': 669, "width": 669,
'height': 1094, "height": 1094,
'altitude': 1.5, "altitude": 1.5,
'maxZoom': 20, "maxZoom": 20,
'minZoom': 0, "minZoom": 0,
'maxPitch': 60, "maxPitch": 60,
'minPitch': 0, "minPitch": 0,
'maxLatitude': 85.05113, "maxLatitude": 85.05113,
'minLatitude': -85.05113, "minLatitude": -85.05113,
}, },
'color_picker': { "color_picker": {"r": 0, "g": 122, "b": 135, "a": 1},
'r': 0, "line_width": 150,
'g': 122, "reverse_long_lat": False,
'b': 135, "js_columns": ["color"],
'a': 1, "js_data_mutator": "data => data.map(d => ({\n"
}, " ...d,\n"
'line_width': 150, " color: colors.hexToRGB(d.extraProps.color)\n"
'reverse_long_lat': False, "}));",
'js_columns': [ "js_tooltip": "",
'color', "js_onclick_href": "",
], "where": "",
'js_data_mutator': 'data => data.map(d => ({\n' "having": "",
' ...d,\n' "filters": [],
' color: colors.hexToRGB(d.extraProps.color)\n'
'}));',
'js_tooltip': '',
'js_onclick_href': '',
'where': '',
'having': '',
'filters': [],
} }
print('Creating Path slice') print("Creating Path slice")
slc = Slice( slc = Slice(
slice_name='Path', slice_name="Path",
viz_type='deck_path', viz_type="deck_path",
datasource_type='table', datasource_type="table",
        datasource_id=db.session.query(TBL).filter_by(table_name='bart_lines').first().id,
        datasource_id=db.session.query(TBL)
        .filter_by(table_name="bart_lines")
        .first()
        .id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
slices.append(slc) slices.append(slc)
slug = 'deck' slug = "deck"
print('Creating a dashboard') print("Creating a dashboard")
title = 'deck.gl Demo' title = "deck.gl Demo"
dash = db.session.query(Dash).filter_by(slug=slug).first() dash = db.session.query(Dash).filter_by(slug=slug).first()
if not dash: if not dash:
@ -574,5 +509,5 @@ def load_deck_dash():
db.session.commit() db.session.commit()
if __name__ == '__main__': if __name__ == "__main__":
load_deck_dash() load_deck_dash()
View File
@ -26,51 +26,53 @@ from superset import db
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
from superset.utils import core as utils from superset.utils import core as utils
from .helpers import (
    DATA_FOLDER, get_example_data, merge_slice, misc_dash_slices, Slice, TBL,
)
from .helpers import (
    DATA_FOLDER,
    get_example_data,
    merge_slice,
    misc_dash_slices,
    Slice,
    TBL,
)
def load_energy(): def load_energy():
"""Loads an energy related dataset to use with sankey and graphs""" """Loads an energy related dataset to use with sankey and graphs"""
tbl_name = 'energy_usage' tbl_name = "energy_usage"
data = get_example_data('energy.json.gz') data = get_example_data("energy.json.gz")
pdf = pd.read_json(data) pdf = pd.read_json(data)
pdf.to_sql( pdf.to_sql(
tbl_name, tbl_name,
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={"source": String(255), "target": String(255), "value": Float()},
'source': String(255), index=False,
'target': String(255), )
'value': Float(),
},
index=False)
print('Creating table [wb_health_population] reference') print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl: if not tbl:
tbl = TBL(table_name=tbl_name) tbl = TBL(table_name=tbl_name)
tbl.description = 'Energy consumption' tbl.description = "Energy consumption"
tbl.database = utils.get_or_create_main_db() tbl.database = utils.get_or_create_main_db()
if not any(col.metric_name == 'sum__value' for col in tbl.metrics): if not any(col.metric_name == "sum__value" for col in tbl.metrics):
col = str(column('value').compile(db.engine)) col = str(column("value").compile(db.engine))
    tbl.metrics.append(SqlMetric(
        metric_name='sum__value',
        expression=f'SUM({col})',
    ))
    tbl.metrics.append(
        SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
    )
db.session.merge(tbl) db.session.merge(tbl)
db.session.commit() db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()
slc = Slice( slc = Slice(
slice_name='Energy Sankey', slice_name="Energy Sankey",
viz_type='sankey', viz_type="sankey",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=textwrap.dedent("""\ params=textwrap.dedent(
"""\
{ {
"collapsed_fieldsets": "", "collapsed_fieldsets": "",
"groupby": [ "groupby": [
@ -84,17 +86,19 @@ def load_energy():
"viz_type": "sankey", "viz_type": "sankey",
"where": "" "where": ""
} }
"""), """
),
) )
misc_dash_slices.add(slc.slice_name) misc_dash_slices.add(slc.slice_name)
merge_slice(slc) merge_slice(slc)
slc = Slice( slc = Slice(
slice_name='Energy Force Layout', slice_name="Energy Force Layout",
viz_type='directed_force', viz_type="directed_force",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=textwrap.dedent("""\ params=textwrap.dedent(
"""\
{ {
"charge": "-500", "charge": "-500",
"collapsed_fieldsets": "", "collapsed_fieldsets": "",
@ -110,17 +114,19 @@ def load_energy():
"viz_type": "directed_force", "viz_type": "directed_force",
"where": "" "where": ""
} }
"""), """
),
) )
misc_dash_slices.add(slc.slice_name) misc_dash_slices.add(slc.slice_name)
merge_slice(slc) merge_slice(slc)
slc = Slice( slc = Slice(
slice_name='Heatmap', slice_name="Heatmap",
viz_type='heatmap', viz_type="heatmap",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=textwrap.dedent("""\ params=textwrap.dedent(
"""\
{ {
"all_columns_x": "source", "all_columns_x": "source",
"all_columns_y": "target", "all_columns_y": "target",
@ -136,7 +142,8 @@ def load_energy():
"xscale_interval": "1", "xscale_interval": "1",
"yscale_interval": "1" "yscale_interval": "1"
} }
"""), """
),
) )
misc_dash_slices.add(slc.slice_name) misc_dash_slices.add(slc.slice_name)
merge_slice(slc) merge_slice(slc)
View File
@ -24,38 +24,37 @@ from .helpers import get_example_data, TBL
def load_flights(): def load_flights():
"""Loading random time series data from a zip file in the repo""" """Loading random time series data from a zip file in the repo"""
tbl_name = 'flights' tbl_name = "flights"
data = get_example_data('flight_data.csv.gz', make_bytes=True) data = get_example_data("flight_data.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding='latin-1') pdf = pd.read_csv(data, encoding="latin-1")
# Loading airports info to join and get lat/long # Loading airports info to join and get lat/long
airports_bytes = get_example_data('airports.csv.gz', make_bytes=True) airports_bytes = get_example_data("airports.csv.gz", make_bytes=True)
airports = pd.read_csv(airports_bytes, encoding='latin-1') airports = pd.read_csv(airports_bytes, encoding="latin-1")
airports = airports.set_index('IATA_CODE') airports = airports.set_index("IATA_CODE")
pdf['ds'] = pdf.YEAR.map(str) + '-0' + pdf.MONTH.map(str) + '-0' + pdf.DAY.map(str) pdf["ds"] = pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
pdf.ds = pd.to_datetime(pdf.ds) pdf.ds = pd.to_datetime(pdf.ds)
del pdf['YEAR'] del pdf["YEAR"]
del pdf['MONTH'] del pdf["MONTH"]
del pdf['DAY'] del pdf["DAY"]
pdf = pdf.join(airports, on='ORIGIN_AIRPORT', rsuffix='_ORIG') pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
pdf = pdf.join(airports, on='DESTINATION_AIRPORT', rsuffix='_DEST') pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql( pdf.to_sql(
tbl_name, tbl_name,
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={"ds": DateTime},
'ds': DateTime, index=False,
}, )
index=False)
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl: if not tbl:
tbl = TBL(table_name=tbl_name) tbl = TBL(table_name=tbl_name)
tbl.description = 'Random set of flights in the US' tbl.description = "Random set of flights in the US"
tbl.database = utils.get_or_create_main_db() tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl) db.session.merge(tbl)
db.session.commit() db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()
print('Done loading table!') print("Done loading table!")
View File
@ -27,31 +27,32 @@ from superset import app, db
from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.connector_registry import ConnectorRegistry
from superset.models import core as models from superset.models import core as models
BASE_URL = 'https://github.com/apache-superset/examples-data/blob/master/' BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"
# Shortcuts # Shortcuts
DB = models.Database DB = models.Database
Slice = models.Slice Slice = models.Slice
Dash = models.Dashboard Dash = models.Dashboard
TBL = ConnectorRegistry.sources['table'] TBL = ConnectorRegistry.sources["table"]
config = app.config config = app.config
DATA_FOLDER = os.path.join(config.get('BASE_DIR'), 'data') DATA_FOLDER = os.path.join(config.get("BASE_DIR"), "data")
misc_dash_slices = set() # slices assembled in a 'Misc Chart' dashboard misc_dash_slices = set() # slices assembled in a 'Misc Chart' dashboard
def update_slice_ids(layout_dict, slices): def update_slice_ids(layout_dict, slices):
    charts = [
        component for component in layout_dict.values()
        if isinstance(component, dict) and component['type'] == 'CHART'
    ]
    charts = [
        component
        for component in layout_dict.values()
        if isinstance(component, dict) and component["type"] == "CHART"
    ]
sorted_charts = sorted(charts, key=lambda k: k['meta']['chartId']) sorted_charts = sorted(charts, key=lambda k: k["meta"]["chartId"])
for i, chart_component in enumerate(sorted_charts): for i, chart_component in enumerate(sorted_charts):
if i < len(slices): if i < len(slices):
chart_component['meta']['chartId'] = int(slices[i].id) chart_component["meta"]["chartId"] = int(slices[i].id)
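Aside: a rough illustration of what update_slice_ids does to a dashboard layout, using stand-in slice objects (real callers pass Slice model instances):

    from types import SimpleNamespace

    layout = {
        "CHART-1": {"type": "CHART", "meta": {"chartId": 101}},
        "CHART-2": {"type": "CHART", "meta": {"chartId": 102}},
        "ROOT": {"type": "ROOT_ID"},
    }
    saved = [SimpleNamespace(id=7), SimpleNamespace(id=8)]
    update_slice_ids(layout, saved)
    # The CHART components are sorted by their old chartId and repointed at the
    # freshly merged slice ids, so CHART-1 now carries 7 and CHART-2 carries 8.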
def merge_slice(slc): def merge_slice(slc):
@ -69,9 +70,9 @@ def get_slice_json(defaults, **kwargs):
def get_example_data(filepath, is_gzip=True, make_bytes=False): def get_example_data(filepath, is_gzip=True, make_bytes=False):
content = requests.get(f'{BASE_URL}{filepath}?raw=true').content content = requests.get(f"{BASE_URL}{filepath}?raw=true").content
if is_gzip: if is_gzip:
content = zlib.decompress(content, zlib.MAX_WBITS|16) content = zlib.decompress(content, zlib.MAX_WBITS | 16)
if make_bytes: if make_bytes:
content = BytesIO(content) content = BytesIO(content)
return content return content
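Aside: the zlib.MAX_WBITS | 16 window value is what lets zlib.decompress unwrap the gzip-framed payloads these .gz example files use; a self-contained check of that pattern:

    import gzip
    import zlib

    payload = gzip.compress(b"example bytes")
    # wbits = MAX_WBITS | 16 tells zlib to expect (and strip) a gzip header and trailer.
    assert zlib.decompress(payload, zlib.MAX_WBITS | 16) == b"example bytes"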
View File
@ -35,50 +35,49 @@ from .helpers import (
def load_long_lat_data(): def load_long_lat_data():
"""Loading lat/long data from a csv file in the repo""" """Loading lat/long data from a csv file in the repo"""
data = get_example_data('san_francisco.csv.gz', make_bytes=True) data = get_example_data("san_francisco.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding='utf-8') pdf = pd.read_csv(data, encoding="utf-8")
start = datetime.datetime.now().replace( start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
hour=0, minute=0, second=0, microsecond=0) pdf["datetime"] = [
pdf['datetime'] = [
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1)) start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
for i in range(len(pdf)) for i in range(len(pdf))
] ]
pdf['occupancy'] = [random.randint(1, 6) for _ in range(len(pdf))] pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf['radius_miles'] = [random.uniform(1, 3) for _ in range(len(pdf))] pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf['geohash'] = pdf[['LAT', 'LON']].apply( pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
lambda x: geohash.encode(*x), axis=1) pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf['delimited'] = pdf['LAT'].map(str).str.cat(pdf['LON'].map(str), sep=',')
pdf.to_sql( # pylint: disable=no-member pdf.to_sql( # pylint: disable=no-member
'long_lat', "long_lat",
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'longitude': Float(), "longitude": Float(),
'latitude': Float(), "latitude": Float(),
'number': Float(), "number": Float(),
'street': String(100), "street": String(100),
'unit': String(10), "unit": String(10),
'city': String(50), "city": String(50),
'district': String(50), "district": String(50),
'region': String(50), "region": String(50),
'postcode': Float(), "postcode": Float(),
'id': String(100), "id": String(100),
'datetime': DateTime(), "datetime": DateTime(),
'occupancy': Float(), "occupancy": Float(),
'radius_miles': Float(), "radius_miles": Float(),
'geohash': String(12), "geohash": String(12),
'delimited': String(60), "delimited": String(60),
}, },
index=False) index=False,
print('Done loading table!') )
print('-' * 80) print("Done loading table!")
print("-" * 80)
print('Creating table reference') print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name='long_lat').first() obj = db.session.query(TBL).filter_by(table_name="long_lat").first()
if not obj: if not obj:
obj = TBL(table_name='long_lat') obj = TBL(table_name="long_lat")
obj.main_dttm_col = 'datetime' obj.main_dttm_col = "datetime"
obj.database = utils.get_or_create_main_db() obj.database = utils.get_or_create_main_db()
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
@ -86,23 +85,23 @@ def load_long_lat_data():
tbl = obj tbl = obj
slice_data = { slice_data = {
'granularity_sqla': 'day', "granularity_sqla": "day",
'since': '2014-01-01', "since": "2014-01-01",
'until': 'now', "until": "now",
'where': '', "where": "",
'viz_type': 'mapbox', "viz_type": "mapbox",
'all_columns_x': 'LON', "all_columns_x": "LON",
'all_columns_y': 'LAT', "all_columns_y": "LAT",
'mapbox_style': 'mapbox://styles/mapbox/light-v9', "mapbox_style": "mapbox://styles/mapbox/light-v9",
'all_columns': ['occupancy'], "all_columns": ["occupancy"],
'row_limit': 500000, "row_limit": 500000,
} }
print('Creating a slice') print("Creating a slice")
slc = Slice( slc = Slice(
slice_name='Mapbox Long/Lat', slice_name="Mapbox Long/Lat",
viz_type='mapbox', viz_type="mapbox",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
View File
@ -19,26 +19,22 @@ import textwrap
from superset import db from superset import db
from .helpers import (
    Dash,
    misc_dash_slices,
    Slice,
    update_slice_ids,
)
from .helpers import Dash, misc_dash_slices, Slice, update_slice_ids
DASH_SLUG = 'misc_charts' DASH_SLUG = "misc_charts"
def load_misc_dashboard(): def load_misc_dashboard():
"""Loading a dashboard featuring misc charts""" """Loading a dashboard featuring misc charts"""
print('Creating the dashboard') print("Creating the dashboard")
db.session.expunge_all() db.session.expunge_all()
dash = db.session.query(Dash).filter_by(slug=DASH_SLUG).first() dash = db.session.query(Dash).filter_by(slug=DASH_SLUG).first()
if not dash: if not dash:
dash = Dash() dash = Dash()
js = textwrap.dedent("""\ js = textwrap.dedent(
"""\
{ {
"CHART-BkeVbh8ANQ": { "CHART-BkeVbh8ANQ": {
"children": [], "children": [],
@ -210,17 +206,15 @@ def load_misc_dashboard():
}, },
"DASHBOARD_VERSION_KEY": "v2" "DASHBOARD_VERSION_KEY": "v2"
} }
""") """
)
pos = json.loads(js) pos = json.loads(js)
    slices = (
        db.session
        .query(Slice)
        .filter(Slice.slice_name.in_(misc_dash_slices))
        .all()
    )
    slices = (
        db.session.query(Slice).filter(Slice.slice_name.in_(misc_dash_slices)).all()
    )
slices = sorted(slices, key=lambda x: x.id) slices = sorted(slices, key=lambda x: x.id)
update_slice_ids(pos, slices) update_slice_ids(pos, slices)
dash.dashboard_title = 'Misc Charts' dash.dashboard_title = "Misc Charts"
dash.position_json = json.dumps(pos, indent=4) dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG dash.slug = DASH_SLUG
dash.slices = slices dash.slices = slices
View File
@ -18,11 +18,7 @@ import json
from superset import db from superset import db
from .birth_names import load_birth_names from .birth_names import load_birth_names
from .helpers import (
    merge_slice,
    misc_dash_slices,
    Slice,
)
from .helpers import merge_slice, misc_dash_slices, Slice
from .world_bank import load_world_bank_health_n_pop from .world_bank import load_world_bank_health_n_pop
@ -30,27 +26,30 @@ def load_multi_line():
load_world_bank_health_n_pop() load_world_bank_health_n_pop()
load_birth_names() load_birth_names()
ids = [ ids = [
row.id for row in row.id
db.session.query(Slice).filter( for row in db.session.query(Slice).filter(
Slice.slice_name.in_(['Growth Rate', 'Trends'])) Slice.slice_name.in_(["Growth Rate", "Trends"])
)
] ]
slc = Slice( slc = Slice(
datasource_type='table', # not true, but needed datasource_type="table", # not true, but needed
datasource_id=1, # cannot be empty datasource_id=1, # cannot be empty
slice_name='Multi Line', slice_name="Multi Line",
viz_type='line_multi', viz_type="line_multi",
params=json.dumps({ params=json.dumps(
'slice_name': 'Multi Line', {
'viz_type': 'line_multi', "slice_name": "Multi Line",
'line_charts': [ids[0]], "viz_type": "line_multi",
'line_charts_2': [ids[1]], "line_charts": [ids[0]],
'since': '1970', "line_charts_2": [ids[1]],
'until': '1995', "since": "1970",
'prefix_metric_with_slice_name': True, "until": "1995",
'show_legend': False, "prefix_metric_with_slice_name": True,
'x_axis_format': '%Y', "show_legend": False,
}), "x_axis_format": "%Y",
}
),
) )
misc_dash_slices.add(slc.slice_name) misc_dash_slices.add(slc.slice_name)
View File
@ -33,44 +33,45 @@ from .helpers import (
def load_multiformat_time_series(): def load_multiformat_time_series():
"""Loading time series data from a zip file in the repo""" """Loading time series data from a zip file in the repo"""
data = get_example_data('multiformat_time_series.json.gz') data = get_example_data("multiformat_time_series.json.gz")
pdf = pd.read_json(data) pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit='s') pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit='s') pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.to_sql( pdf.to_sql(
'multiformat_time_series', "multiformat_time_series",
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'ds': Date, "ds": Date,
'ds2': DateTime, "ds2": DateTime,
'epoch_s': BigInteger, "epoch_s": BigInteger,
'epoch_ms': BigInteger, "epoch_ms": BigInteger,
'string0': String(100), "string0": String(100),
'string1': String(100), "string1": String(100),
'string2': String(100), "string2": String(100),
'string3': String(100), "string3": String(100),
}, },
index=False) index=False,
print('Done loading table!') )
print('-' * 80) print("Done loading table!")
print('Creating table [multiformat_time_series] reference') print("-" * 80)
obj = db.session.query(TBL).filter_by(table_name='multiformat_time_series').first() print("Creating table [multiformat_time_series] reference")
obj = db.session.query(TBL).filter_by(table_name="multiformat_time_series").first()
if not obj: if not obj:
obj = TBL(table_name='multiformat_time_series') obj = TBL(table_name="multiformat_time_series")
obj.main_dttm_col = 'ds' obj.main_dttm_col = "ds"
obj.database = utils.get_or_create_main_db() obj.database = utils.get_or_create_main_db()
dttm_and_expr_dict = { dttm_and_expr_dict = {
'ds': [None, None], "ds": [None, None],
'ds2': [None, None], "ds2": [None, None],
'epoch_s': ['epoch_s', None], "epoch_s": ["epoch_s", None],
'epoch_ms': ['epoch_ms', None], "epoch_ms": ["epoch_ms", None],
'string2': ['%Y%m%d-%H%M%S', None], "string2": ["%Y%m%d-%H%M%S", None],
'string1': ['%Y-%m-%d^%H:%M:%S', None], "string1": ["%Y-%m-%d^%H:%M:%S", None],
'string0': ['%Y-%m-%d %H:%M:%S.%f', None], "string0": ["%Y-%m-%d %H:%M:%S.%f", None],
'string3': ['%Y/%m/%d%H:%M:%S.%f', None], "string3": ["%Y/%m/%d%H:%M:%S.%f", None],
} }
for col in obj.columns: for col in obj.columns:
dttm_and_expr = dttm_and_expr_dict[col.column_name] dttm_and_expr = dttm_and_expr_dict[col.column_name]
@ -82,26 +83,26 @@ def load_multiformat_time_series():
obj.fetch_metadata() obj.fetch_metadata()
tbl = obj tbl = obj
print('Creating Heatmap charts') print("Creating Heatmap charts")
for i, col in enumerate(tbl.columns): for i, col in enumerate(tbl.columns):
slice_data = { slice_data = {
'metrics': ['count'], "metrics": ["count"],
'granularity_sqla': col.column_name, "granularity_sqla": col.column_name,
'row_limit': config.get('ROW_LIMIT'), "row_limit": config.get("ROW_LIMIT"),
'since': '2015', "since": "2015",
'until': '2016', "until": "2016",
'where': '', "where": "",
'viz_type': 'cal_heatmap', "viz_type": "cal_heatmap",
'domain_granularity': 'month', "domain_granularity": "month",
'subdomain_granularity': 'day', "subdomain_granularity": "day",
} }
slc = Slice( slc = Slice(
slice_name=f'Calendar Heatmap multiformat {i}', slice_name=f"Calendar Heatmap multiformat {i}",
viz_type='cal_heatmap', viz_type="cal_heatmap",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
misc_dash_slices.add('Calendar Heatmap multiformat 0') misc_dash_slices.add("Calendar Heatmap multiformat 0")
View File
@ -25,29 +25,30 @@ from .helpers import TBL, get_example_data
def load_paris_iris_geojson(): def load_paris_iris_geojson():
tbl_name = 'paris_iris_mapping' tbl_name = "paris_iris_mapping"
data = get_example_data('paris_iris.json.gz') data = get_example_data("paris_iris.json.gz")
df = pd.read_json(data) df = pd.read_json(data)
df['features'] = df.features.map(json.dumps) df["features"] = df.features.map(json.dumps)
df.to_sql( df.to_sql(
tbl_name, tbl_name,
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'color': String(255), "color": String(255),
'name': String(255), "name": String(255),
'features': Text, "features": Text,
'type': Text, "type": Text,
}, },
index=False) index=False,
print('Creating table {} reference'.format(tbl_name)) )
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl: if not tbl:
tbl = TBL(table_name=tbl_name) tbl = TBL(table_name=tbl_name)
tbl.description = 'Map of Paris' tbl.description = "Map of Paris"
tbl.database = utils.get_or_create_main_db() tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl) db.session.merge(tbl)
db.session.commit() db.session.commit()
View File
@ -20,38 +20,30 @@ from sqlalchemy import DateTime
from superset import db from superset import db
from superset.utils import core as utils from superset.utils import core as utils
from .helpers import (
    config,
    get_example_data,
    get_slice_json,
    merge_slice,
    Slice,
    TBL,
)
from .helpers import config, get_example_data, get_slice_json, merge_slice, Slice, TBL
def load_random_time_series_data(): def load_random_time_series_data():
"""Loading random time series data from a zip file in the repo""" """Loading random time series data from a zip file in the repo"""
data = get_example_data('random_time_series.json.gz') data = get_example_data("random_time_series.json.gz")
pdf = pd.read_json(data) pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit='s') pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.to_sql( pdf.to_sql(
'random_time_series', "random_time_series",
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={"ds": DateTime},
'ds': DateTime, index=False,
}, )
index=False) print("Done loading table!")
print('Done loading table!') print("-" * 80)
print('-' * 80)
print('Creating table [random_time_series] reference') print("Creating table [random_time_series] reference")
obj = db.session.query(TBL).filter_by(table_name='random_time_series').first() obj = db.session.query(TBL).filter_by(table_name="random_time_series").first()
if not obj: if not obj:
obj = TBL(table_name='random_time_series') obj = TBL(table_name="random_time_series")
obj.main_dttm_col = 'ds' obj.main_dttm_col = "ds"
obj.database = utils.get_or_create_main_db() obj.database = utils.get_or_create_main_db()
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
@ -59,22 +51,22 @@ def load_random_time_series_data():
tbl = obj tbl = obj
slice_data = { slice_data = {
'granularity_sqla': 'day', "granularity_sqla": "day",
'row_limit': config.get('ROW_LIMIT'), "row_limit": config.get("ROW_LIMIT"),
'since': '1 year ago', "since": "1 year ago",
'until': 'now', "until": "now",
'metric': 'count', "metric": "count",
'where': '', "where": "",
'viz_type': 'cal_heatmap', "viz_type": "cal_heatmap",
'domain_granularity': 'month', "domain_granularity": "month",
'subdomain_granularity': 'day', "subdomain_granularity": "day",
} }
print('Creating a slice') print("Creating a slice")
slc = Slice( slc = Slice(
slice_name='Calendar Heatmap', slice_name="Calendar Heatmap",
viz_type='cal_heatmap', viz_type="cal_heatmap",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
View File
@ -25,29 +25,30 @@ from .helpers import TBL, get_example_data
def load_sf_population_polygons(): def load_sf_population_polygons():
tbl_name = 'sf_population_polygons' tbl_name = "sf_population_polygons"
data = get_example_data('sf_population.json.gz') data = get_example_data("sf_population.json.gz")
df = pd.read_json(data) df = pd.read_json(data)
df['contour'] = df.contour.map(json.dumps) df["contour"] = df.contour.map(json.dumps)
df.to_sql( df.to_sql(
tbl_name, tbl_name,
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'zipcode': BigInteger, "zipcode": BigInteger,
'population': BigInteger, "population": BigInteger,
'contour': Text, "contour": Text,
'area': BigInteger, "area": BigInteger,
}, },
index=False) index=False,
print('Creating table {} reference'.format(tbl_name)) )
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl: if not tbl:
tbl = TBL(table_name=tbl_name) tbl = TBL(table_name=tbl_name)
tbl.description = 'Population density of San Francisco' tbl.description = "Population density of San Francisco"
tbl.database = utils.get_or_create_main_db() tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl) db.session.merge(tbl)
db.session.commit() db.session.commit()
View File
@ -44,7 +44,7 @@ def load_tabbed_dashboard():
"""Creating a tabbed dashboard""" """Creating a tabbed dashboard"""
print("Creating a dashboard with nested tabs") print("Creating a dashboard with nested tabs")
slug = 'tabbed_dash' slug = "tabbed_dash"
dash = db.session.query(Dash).filter_by(slug=slug).first() dash = db.session.query(Dash).filter_by(slug=slug).first()
if not dash: if not dash:
@ -53,12 +53,13 @@ def load_tabbed_dashboard():
# reuse charts in "World's Bank Data and create # reuse charts in "World's Bank Data and create
# new dashboard with nested tabs # new dashboard with nested tabs
tabbed_dash_slices = set() tabbed_dash_slices = set()
tabbed_dash_slices.add('Region Filter') tabbed_dash_slices.add("Region Filter")
tabbed_dash_slices.add('Growth Rate') tabbed_dash_slices.add("Growth Rate")
tabbed_dash_slices.add('Treemap') tabbed_dash_slices.add("Treemap")
tabbed_dash_slices.add('Box plot') tabbed_dash_slices.add("Box plot")
js = textwrap.dedent("""\ js = textwrap.dedent(
"""\
{ {
"CHART-c0EjR-OZ0n": { "CHART-c0EjR-OZ0n": {
"children": [], "children": [],
@ -337,12 +338,11 @@ def load_tabbed_dashboard():
"type": "TABS" "type": "TABS"
} }
} }
""") """
)
pos = json.loads(js) pos = json.loads(js)
    slices = [
        db.session.query(Slice)
        .filter_by(slice_name=name)
        .first()
        for name in tabbed_dash_slices
    ]
    slices = [
        db.session.query(Slice).filter_by(slice_name=name).first()
        for name in tabbed_dash_slices
    ]
@ -350,7 +350,7 @@ def load_tabbed_dashboard():
update_slice_ids(pos, slices) update_slice_ids(pos, slices)
dash.position_json = json.dumps(pos, indent=4) dash.position_json = json.dumps(pos, indent=4)
dash.slices = slices dash.slices = slices
dash.dashboard_title = 'Tabbed Dashboard' dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug dash.slug = slug
db.session.merge(dash) db.session.merge(dash)
View File
@ -38,32 +38,34 @@ from .helpers import (
def load_unicode_test_data(): def load_unicode_test_data():
"""Loading unicode test dataset from a csv file in the repo""" """Loading unicode test dataset from a csv file in the repo"""
data = get_example_data( data = get_example_data(
'unicode_utf8_unixnl_test.csv', is_gzip=False, make_bytes=True) "unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True
df = pd.read_csv(data, encoding='utf-8') )
df = pd.read_csv(data, encoding="utf-8")
# generate date/numeric data # generate date/numeric data
df['dttm'] = datetime.datetime.now().date() df["dttm"] = datetime.datetime.now().date()
df['value'] = [random.randint(1, 100) for _ in range(len(df))] df["value"] = [random.randint(1, 100) for _ in range(len(df))]
df.to_sql( # pylint: disable=no-member df.to_sql( # pylint: disable=no-member
'unicode_test', "unicode_test",
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=500, chunksize=500,
dtype={ dtype={
'phrase': String(500), "phrase": String(500),
'short_phrase': String(10), "short_phrase": String(10),
'with_missing': String(100), "with_missing": String(100),
'dttm': Date(), "dttm": Date(),
'value': Float(), "value": Float(),
}, },
index=False) index=False,
print('Done loading table!') )
print('-' * 80) print("Done loading table!")
print("-" * 80)
print('Creating table [unicode_test] reference') print("Creating table [unicode_test] reference")
obj = db.session.query(TBL).filter_by(table_name='unicode_test').first() obj = db.session.query(TBL).filter_by(table_name="unicode_test").first()
if not obj: if not obj:
obj = TBL(table_name='unicode_test') obj = TBL(table_name="unicode_test")
obj.main_dttm_col = 'dttm' obj.main_dttm_col = "dttm"
obj.database = utils.get_or_create_main_db() obj.database = utils.get_or_create_main_db()
db.session.merge(obj) db.session.merge(obj)
db.session.commit() db.session.commit()
@ -71,37 +73,33 @@ def load_unicode_test_data():
tbl = obj tbl = obj
slice_data = { slice_data = {
'granularity_sqla': 'dttm', "granularity_sqla": "dttm",
'groupby': [], "groupby": [],
'metric': 'sum__value', "metric": "sum__value",
'row_limit': config.get('ROW_LIMIT'), "row_limit": config.get("ROW_LIMIT"),
'since': '100 years ago', "since": "100 years ago",
'until': 'now', "until": "now",
'where': '', "where": "",
'viz_type': 'word_cloud', "viz_type": "word_cloud",
'size_from': '10', "size_from": "10",
'series': 'short_phrase', "series": "short_phrase",
'size_to': '70', "size_to": "70",
'rotation': 'square', "rotation": "square",
'limit': '100', "limit": "100",
} }
print('Creating a slice') print("Creating a slice")
slc = Slice( slc = Slice(
slice_name='Unicode Cloud', slice_name="Unicode Cloud",
viz_type='word_cloud', viz_type="word_cloud",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json(slice_data), params=get_slice_json(slice_data),
) )
merge_slice(slc) merge_slice(slc)
print('Creating a dashboard') print("Creating a dashboard")
    dash = (
        db.session.query(Dash)
        .filter_by(dashboard_title='Unicode Test')
        .first()
    )
    dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first()
if not dash: if not dash:
dash = Dash() dash = Dash()
@ -145,11 +143,11 @@ def load_unicode_test_data():
"DASHBOARD_VERSION_KEY": "v2" "DASHBOARD_VERSION_KEY": "v2"
} }
""" """
dash.dashboard_title = 'Unicode Test' dash.dashboard_title = "Unicode Test"
pos = json.loads(js) pos = json.loads(js)
update_slice_ids(pos, [slc]) update_slice_ids(pos, [slc])
dash.position_json = json.dumps(pos, indent=4) dash.position_json = json.dumps(pos, indent=4)
dash.slug = 'unicode-test' dash.slug = "unicode-test"
dash.slices = [slc] dash.slices = [slc]
db.session.merge(dash) db.session.merge(dash)
db.session.commit() db.session.commit()
View File
@ -43,232 +43,270 @@ from .helpers import (
def load_world_bank_health_n_pop(): def load_world_bank_health_n_pop():
"""Loads the world bank health dataset, slices and a dashboard""" """Loads the world bank health dataset, slices and a dashboard"""
tbl_name = 'wb_health_population' tbl_name = "wb_health_population"
data = get_example_data('countries.json.gz') data = get_example_data("countries.json.gz")
pdf = pd.read_json(data) pdf = pd.read_json(data)
pdf.columns = [col.replace('.', '_') for col in pdf.columns] pdf.columns = [col.replace(".", "_") for col in pdf.columns]
pdf.year = pd.to_datetime(pdf.year) pdf.year = pd.to_datetime(pdf.year)
pdf.to_sql( pdf.to_sql(
tbl_name, tbl_name,
db.engine, db.engine,
if_exists='replace', if_exists="replace",
chunksize=50, chunksize=50,
dtype={ dtype={
'year': DateTime(), "year": DateTime(),
'country_code': String(3), "country_code": String(3),
'country_name': String(255), "country_name": String(255),
'region': String(255), "region": String(255),
}, },
index=False) index=False,
)
print('Creating table [wb_health_population] reference') print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first() tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl: if not tbl:
tbl = TBL(table_name=tbl_name) tbl = TBL(table_name=tbl_name)
tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md')) tbl.description = utils.readfile(os.path.join(DATA_FOLDER, "countries.md"))
tbl.main_dttm_col = 'year' tbl.main_dttm_col = "year"
tbl.database = utils.get_or_create_main_db() tbl.database = utils.get_or_create_main_db()
tbl.filter_select_enabled = True tbl.filter_select_enabled = True
    metrics = [
        'sum__SP_POP_TOTL', 'sum__SH_DYN_AIDS', 'sum__SH_DYN_AIDS',
        'sum__SP_RUR_TOTL_ZS', 'sum__SP_DYN_LE00_IN', 'sum__SP_RUR_TOTL'
    ]
    metrics = [
        "sum__SP_POP_TOTL",
        "sum__SH_DYN_AIDS",
        "sum__SH_DYN_AIDS",
        "sum__SP_RUR_TOTL_ZS",
        "sum__SP_DYN_LE00_IN",
        "sum__SP_RUR_TOTL",
    ]
for m in metrics: for m in metrics:
if not any(col.metric_name == m for col in tbl.metrics): if not any(col.metric_name == m for col in tbl.metrics):
aggr_func = m[:3] aggr_func = m[:3]
col = str(column(m[5:]).compile(db.engine)) col = str(column(m[5:]).compile(db.engine))
            tbl.metrics.append(SqlMetric(
                metric_name=m,
                expression=f'{aggr_func}({col})',
            ))
            tbl.metrics.append(
                SqlMetric(metric_name=m, expression=f"{aggr_func}({col})")
            )
db.session.merge(tbl) db.session.merge(tbl)
db.session.commit() db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()
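Aside: the metric names follow an aggregate__column convention, which is why the loop above can derive each SQL expression from the name alone; stripped of the SQLAlchemy quoting step, the derivation is roughly:

    m = "sum__SP_POP_TOTL"
    aggr_func = m[:3]  # "sum"
    col = m[5:]        # "SP_POP_TOTL" (the real code compiles column(col) against db.engine)
    assert f"{aggr_func}({col})" == "sum(SP_POP_TOTL)"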
defaults = { defaults = {
'compare_lag': '10', "compare_lag": "10",
'compare_suffix': 'o10Y', "compare_suffix": "o10Y",
'limit': '25', "limit": "25",
'granularity_sqla': 'year', "granularity_sqla": "year",
'groupby': [], "groupby": [],
'metric': 'sum__SP_POP_TOTL', "metric": "sum__SP_POP_TOTL",
'metrics': ['sum__SP_POP_TOTL'], "metrics": ["sum__SP_POP_TOTL"],
'row_limit': config.get('ROW_LIMIT'), "row_limit": config.get("ROW_LIMIT"),
'since': '2014-01-01', "since": "2014-01-01",
'until': '2014-01-02', "until": "2014-01-02",
'time_range': '2014-01-01 : 2014-01-02', "time_range": "2014-01-01 : 2014-01-02",
'where': '', "where": "",
'markup_type': 'markdown', "markup_type": "markdown",
'country_fieldtype': 'cca3', "country_fieldtype": "cca3",
'secondary_metric': 'sum__SP_POP_TOTL', "secondary_metric": "sum__SP_POP_TOTL",
'entity': 'country_code', "entity": "country_code",
'show_bubbles': True, "show_bubbles": True,
} }
print('Creating slices') print("Creating slices")
slices = [ slices = [
Slice( Slice(
slice_name='Region Filter', slice_name="Region Filter",
viz_type='filter_box', viz_type="filter_box",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='filter_box', viz_type="filter_box",
date_filter=False, date_filter=False,
filter_configs=[ filter_configs=[
{ {
'asc': False, "asc": False,
'clearable': True, "clearable": True,
'column': 'region', "column": "region",
'key': '2s98dfu', "key": "2s98dfu",
'metric': 'sum__SP_POP_TOTL', "metric": "sum__SP_POP_TOTL",
'multiple': True, "multiple": True,
}, {
'asc': False,
'clearable': True,
'key': 'li3j2lk',
'column': 'country_name',
'metric': 'sum__SP_POP_TOTL',
'multiple': True,
}, },
])), {
"asc": False,
"clearable": True,
"key": "li3j2lk",
"column": "country_name",
"metric": "sum__SP_POP_TOTL",
"multiple": True,
},
],
),
),
Slice( Slice(
slice_name="World's Population", slice_name="World's Population",
viz_type='big_number', viz_type="big_number",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since='2000', since="2000",
viz_type='big_number', viz_type="big_number",
compare_lag='10', compare_lag="10",
metric='sum__SP_POP_TOTL', metric="sum__SP_POP_TOTL",
compare_suffix='over 10Y')), compare_suffix="over 10Y",
),
),
Slice( Slice(
slice_name='Most Populated Countries', slice_name="Most Populated Countries",
viz_type='table', viz_type="table",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='table', viz_type="table",
metrics=['sum__SP_POP_TOTL'], metrics=["sum__SP_POP_TOTL"],
groupby=['country_name'])), groupby=["country_name"],
),
),
Slice( Slice(
slice_name='Growth Rate', slice_name="Growth Rate",
viz_type='line', viz_type="line",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='line', viz_type="line",
since='1960-01-01', since="1960-01-01",
metrics=['sum__SP_POP_TOTL'], metrics=["sum__SP_POP_TOTL"],
num_period_compare='10', num_period_compare="10",
groupby=['country_name'])), groupby=["country_name"],
),
),
Slice( Slice(
slice_name='% Rural', slice_name="% Rural",
viz_type='world_map', viz_type="world_map",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='world_map', viz_type="world_map",
metric='sum__SP_RUR_TOTL_ZS', metric="sum__SP_RUR_TOTL_ZS",
num_period_compare='10')), num_period_compare="10",
),
),
Slice( Slice(
slice_name='Life Expectancy VS Rural %', slice_name="Life Expectancy VS Rural %",
viz_type='bubble', viz_type="bubble",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='bubble', viz_type="bubble",
since='2011-01-01', since="2011-01-01",
until='2011-01-02', until="2011-01-02",
series='region', series="region",
limit=0, limit=0,
entity='country_name', entity="country_name",
x='sum__SP_RUR_TOTL_ZS', x="sum__SP_RUR_TOTL_ZS",
y='sum__SP_DYN_LE00_IN', y="sum__SP_DYN_LE00_IN",
size='sum__SP_POP_TOTL', size="sum__SP_POP_TOTL",
max_bubble_size='50', max_bubble_size="50",
filters=[{ filters=[
'col': 'country_code', {
'val': [ "col": "country_code",
'TCA', 'MNP', 'DMA', 'MHL', 'MCO', 'SXM', 'CYM', "val": [
'TUV', 'IMY', 'KNA', 'ASM', 'ADO', 'AMA', 'PLW', "TCA",
], "MNP",
'op': 'not in'}], "DMA",
)), "MHL",
"MCO",
"SXM",
"CYM",
"TUV",
"IMY",
"KNA",
"ASM",
"ADO",
"AMA",
"PLW",
],
"op": "not in",
}
],
),
),
Slice( Slice(
slice_name='Rural Breakdown', slice_name="Rural Breakdown",
viz_type='sunburst', viz_type="sunburst",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
viz_type='sunburst', viz_type="sunburst",
groupby=['region', 'country_name'], groupby=["region", "country_name"],
secondary_metric='sum__SP_RUR_TOTL', secondary_metric="sum__SP_RUR_TOTL",
since='2011-01-01', since="2011-01-01",
until='2011-01-01')), until="2011-01-01",
),
),
Slice( Slice(
slice_name="World's Pop Growth", slice_name="World's Pop Growth",
viz_type='area', viz_type="area",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since='1960-01-01', since="1960-01-01",
until='now', until="now",
viz_type='area', viz_type="area",
groupby=['region'])), groupby=["region"],
),
),
Slice( Slice(
slice_name='Box plot', slice_name="Box plot",
viz_type='box_plot', viz_type="box_plot",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since='1960-01-01', since="1960-01-01",
until='now', until="now",
whisker_options='Min/max (no outliers)', whisker_options="Min/max (no outliers)",
x_ticks_layout='staggered', x_ticks_layout="staggered",
viz_type='box_plot', viz_type="box_plot",
groupby=['region'])), groupby=["region"],
),
),
Slice( Slice(
slice_name='Treemap', slice_name="Treemap",
viz_type='treemap', viz_type="treemap",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since='1960-01-01', since="1960-01-01",
until='now', until="now",
viz_type='treemap', viz_type="treemap",
metrics=['sum__SP_POP_TOTL'], metrics=["sum__SP_POP_TOTL"],
groupby=['region', 'country_code'])), groupby=["region", "country_code"],
),
),
Slice( Slice(
slice_name='Parallel Coordinates', slice_name="Parallel Coordinates",
viz_type='para', viz_type="para",
datasource_type='table', datasource_type="table",
datasource_id=tbl.id, datasource_id=tbl.id,
params=get_slice_json( params=get_slice_json(
defaults, defaults,
since='2011-01-01', since="2011-01-01",
until='2011-01-01', until="2011-01-01",
viz_type='para', viz_type="para",
limit=100, limit=100,
metrics=[ metrics=["sum__SP_POP_TOTL", "sum__SP_RUR_TOTL_ZS", "sum__SH_DYN_AIDS"],
'sum__SP_POP_TOTL', secondary_metric="sum__SP_POP_TOTL",
'sum__SP_RUR_TOTL_ZS', series="country_name",
'sum__SH_DYN_AIDS'], ),
secondary_metric='sum__SP_POP_TOTL', ),
series='country_name')),
] ]
misc_dash_slices.add(slices[-1].slice_name) misc_dash_slices.add(slices[-1].slice_name)
for slc in slices: for slc in slices:
@ -276,12 +314,13 @@ def load_world_bank_health_n_pop():
print("Creating a World's Health Bank dashboard") print("Creating a World's Health Bank dashboard")
dash_name = "World's Bank Data" dash_name = "World's Bank Data"
slug = 'world_health' slug = "world_health"
dash = db.session.query(Dash).filter_by(slug=slug).first() dash = db.session.query(Dash).filter_by(slug=slug).first()
if not dash: if not dash:
dash = Dash() dash = Dash()
js = textwrap.dedent("""\ js = textwrap.dedent(
"""\
{ {
"CHART-36bfc934": { "CHART-36bfc934": {
"children": [], "children": [],
@ -497,7 +536,8 @@ def load_world_bank_health_n_pop():
}, },
"DASHBOARD_VERSION_KEY": "v2" "DASHBOARD_VERSION_KEY": "v2"
} }
""") """
)
pos = json.loads(js) pos = json.loads(js)
update_slice_ids(pos, slices) update_slice_ids(pos, slices)
View File
@ -36,7 +36,7 @@ INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100 INFER_COL_TYPES_SAMPLE_SIZE = 100
def dedup(l, suffix='__', case_sensitive=True): def dedup(l, suffix="__", case_sensitive=True):
"""De-duplicates a list of string by suffixing a counter """De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns Always returns the same number of entries as provided, and always returns
@ -44,7 +44,9 @@ def dedup(l, suffix='__', case_sensitive=True):
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar']))) >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
foo,bar,bar__1,bar__2,Bar foo,bar,bar__1,bar__2,Bar
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False))) >>> print(
','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False))
)
foo,bar,bar__1,bar__2,Bar__3 foo,bar,bar__1,bar__2,Bar__3
""" """
new_l = [] new_l = []
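Aside: the doctest above pins down dedup's contract; a minimal sketch that satisfies it (not the Superset implementation, which continues below this hunk):

    def dedup_sketch(labels, suffix="__", case_sensitive=True):
        seen = {}  # normalized label -> occurrences so far
        out = []
        for label in labels:
            key = label if case_sensitive else label.lower()
            count = seen.get(key, 0)
            out.append(label if count == 0 else f"{label}{suffix}{count}")
            seen[key] = count + 1
        return out

    assert ",".join(dedup_sketch(["foo", "bar", "bar", "bar", "Bar"])) == "foo,bar,bar__1,bar__2,Bar"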
@ -63,18 +65,18 @@ def dedup(l, suffix='__', case_sensitive=True):
class SupersetDataFrame(object): class SupersetDataFrame(object):
# Mapping numpy dtype.char to generic database types # Mapping numpy dtype.char to generic database types
type_map = { type_map = {
'b': 'BOOL', # boolean "b": "BOOL", # boolean
'i': 'INT', # (signed) integer "i": "INT", # (signed) integer
'u': 'INT', # unsigned integer "u": "INT", # unsigned integer
'l': 'INT', # 64bit integer "l": "INT", # 64bit integer
'f': 'FLOAT', # floating-point "f": "FLOAT", # floating-point
'c': 'FLOAT', # complex-floating point "c": "FLOAT", # complex-floating point
'm': None, # timedelta "m": None, # timedelta
'M': 'DATETIME', # datetime "M": "DATETIME", # datetime
'O': 'OBJECT', # (Python) objects "O": "OBJECT", # (Python) objects
'S': 'BYTE', # (byte-)string "S": "BYTE", # (byte-)string
'U': 'STRING', # Unicode "U": "STRING", # Unicode
'V': None, # raw data (void) "V": None, # raw data (void)
} }
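Aside: the single-letter keys appear to be numpy's one-character type codes, which makes this table a cheap way to bucket result columns; for instance, assuming numpy is installed:

    import numpy as np

    assert np.dtype("int64").kind == "i"           # maps to "INT"
    assert np.dtype("float64").kind == "f"         # maps to "FLOAT"
    assert np.dtype("datetime64[ns]").kind == "M"  # maps to "DATETIME"
    assert np.dtype(object).kind == "O"            # maps to "OBJECT"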
def __init__(self, data, cursor_description, db_engine_spec): def __init__(self, data, cursor_description, db_engine_spec):
@ -85,8 +87,7 @@ class SupersetDataFrame(object):
self.column_names = dedup(column_names) self.column_names = dedup(column_names)
data = data or [] data = data or []
        self.df = (
            pd.DataFrame(list(data), columns=self.column_names).infer_objects())
        self.df = pd.DataFrame(list(data), columns=self.column_names).infer_objects()
self._type_dict = {} self._type_dict = {}
try: try:
@ -106,9 +107,13 @@ class SupersetDataFrame(object):
@property @property
def data(self): def data(self):
# work around for https://github.com/pandas-dev/pandas/issues/18372 # work around for https://github.com/pandas-dev/pandas/issues/18372
        data = [dict((k, _maybe_box_datetimelike(v))
                     for k, v in zip(self.df.columns, np.atleast_1d(row)))
                for row in self.df.values]
        data = [
            dict(
                (k, _maybe_box_datetimelike(v))
                for k, v in zip(self.df.columns, np.atleast_1d(row))
            )
            for row in self.df.values
        ]
for d in data: for d in data:
for k, v in list(d.items()): for k, v in list(d.items()):
# if an int is too big for Java Script to handle # if an int is too big for Java Script to handle
@ -123,7 +128,7 @@ class SupersetDataFrame(object):
"""Given a numpy dtype, Returns a generic database type""" """Given a numpy dtype, Returns a generic database type"""
if isinstance(dtype, ExtensionDtype): if isinstance(dtype, ExtensionDtype):
return cls.type_map.get(dtype.kind) return cls.type_map.get(dtype.kind)
elif hasattr(dtype, 'char'): elif hasattr(dtype, "char"):
return cls.type_map.get(dtype.char) return cls.type_map.get(dtype.char)
@classmethod @classmethod
@ -141,10 +146,9 @@ class SupersetDataFrame(object):
@staticmethod @staticmethod
def is_date(np_dtype, db_type_str): def is_date(np_dtype, db_type_str):
def looks_daty(s): def looks_daty(s):
if isinstance(s, str): if isinstance(s, str):
return any([s.lower().startswith(ss) for ss in ('time', 'date')]) return any([s.lower().startswith(ss) for ss in ("time", "date")])
return False return False
if looks_daty(db_type_str): if looks_daty(db_type_str):
@ -157,20 +161,23 @@ class SupersetDataFrame(object):
def is_dimension(cls, dtype, column_name): def is_dimension(cls, dtype, column_name):
if cls.is_id(column_name): if cls.is_id(column_name):
return False return False
return dtype.name in ('object', 'bool') return dtype.name in ("object", "bool")
@classmethod @classmethod
def is_id(cls, column_name): def is_id(cls, column_name):
return column_name.startswith('id') or column_name.endswith('id') return column_name.startswith("id") or column_name.endswith("id")
@classmethod @classmethod
def agg_func(cls, dtype, column_name): def agg_func(cls, dtype, column_name):
# consider checking for key substring too. # consider checking for key substring too.
if cls.is_id(column_name): if cls.is_id(column_name):
return 'count_distinct' return "count_distinct"
        if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
                np.issubdtype(dtype, np.number)):
            return 'sum'
        if (
            hasattr(dtype, "type")
            and issubclass(dtype.type, np.generic)
            and np.issubdtype(dtype, np.number)
        ):
            return "sum"
return None return None
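Aside: agg_func effectively picks a default aggregation from two heuristics: id-like column names get count_distinct, numeric dtypes get sum, everything else gets no default. A condensed restatement of that logic (not the class method itself):

    import numpy as np

    def default_agg(dtype, column_name):
        if column_name.startswith("id") or column_name.endswith("id"):
            return "count_distinct"
        if np.issubdtype(dtype, np.number):
            return "sum"
        return None

    assert default_agg(np.dtype("int64"), "user_id") == "count_distinct"
    assert default_agg(np.dtype("float64"), "revenue") == "sum"
    assert default_agg(np.dtype(object), "name") is None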
@property @property
@ -188,42 +195,36 @@ class SupersetDataFrame(object):
if sample_size: if sample_size:
sample = self.df.sample(sample_size) sample = self.df.sample(sample_size)
for col in self.df.dtypes.keys(): for col in self.df.dtypes.keys():
            db_type_str = (
                self._type_dict.get(col) or
                self.db_type(self.df.dtypes[col])
            )
            db_type_str = self._type_dict.get(col) or self.db_type(self.df.dtypes[col])
column = { column = {
'name': col, "name": col,
'agg': self.agg_func(self.df.dtypes[col], col), "agg": self.agg_func(self.df.dtypes[col], col),
'type': db_type_str, "type": db_type_str,
'is_date': self.is_date(self.df.dtypes[col], db_type_str), "is_date": self.is_date(self.df.dtypes[col], db_type_str),
'is_dim': self.is_dimension(self.df.dtypes[col], col), "is_dim": self.is_dimension(self.df.dtypes[col], col),
} }
if not db_type_str or db_type_str.upper() == 'OBJECT': if not db_type_str or db_type_str.upper() == "OBJECT":
v = sample[col].iloc[0] if not sample[col].empty else None v = sample[col].iloc[0] if not sample[col].empty else None
if isinstance(v, str): if isinstance(v, str):
column['type'] = 'STRING' column["type"] = "STRING"
elif isinstance(v, int): elif isinstance(v, int):
column['type'] = 'INT' column["type"] = "INT"
elif isinstance(v, float): elif isinstance(v, float):
column['type'] = 'FLOAT' column["type"] = "FLOAT"
elif isinstance(v, (datetime, date)): elif isinstance(v, (datetime, date)):
column['type'] = 'DATETIME' column["type"] = "DATETIME"
column['is_date'] = True column["is_date"] = True
column['is_dim'] = False column["is_dim"] = False
# check if encoded datetime # check if encoded datetime
            if (
                    column['type'] == 'STRING' and
                    self.datetime_conversion_rate(sample[col]) >
                    INFER_COL_TYPES_THRESHOLD):
                column.update({
                    'is_date': True,
                    'is_dim': False,
                    'agg': None,
                })
            if (
                column["type"] == "STRING"
                and self.datetime_conversion_rate(sample[col])
                > INFER_COL_TYPES_THRESHOLD
            ):
                column.update({"is_date": True, "is_dim": False, "agg": None})
# 'agg' is optional attribute # 'agg' is optional attribute
if not column['agg']: if not column["agg"]:
column.pop('agg', None) column.pop("agg", None)
columns.append(column) columns.append(column)
return columns return columns
View File
@ -39,11 +39,14 @@ from superset.db_engine_specs.base import BaseEngineSpec
engines: Dict[str, Type[BaseEngineSpec]] = {} engines: Dict[str, Type[BaseEngineSpec]] = {}
for (_, name, _) in pkgutil.iter_modules([Path(__file__).parent]): # type: ignore for (_, name, _) in pkgutil.iter_modules([Path(__file__).parent]): # type: ignore
imported_module = import_module('.' + name, package=__name__) imported_module = import_module("." + name, package=__name__)
for i in dir(imported_module): for i in dir(imported_module):
attribute = getattr(imported_module, i) attribute = getattr(imported_module, i)
if inspect.isclass(attribute) and issubclass(attribute, BaseEngineSpec) \ if (
and attribute.engine != '': inspect.isclass(attribute)
and issubclass(attribute, BaseEngineSpec)
and attribute.engine != ""
):
engines[attribute.engine] = attribute engines[attribute.engine] = attribute
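The loop above is a small plugin registry: every module in the package is imported and each BaseEngineSpec subclass with a non-empty engine attribute is keyed by that name. A generic, hedged sketch of the same pattern (illustrative names, not Superset code):

import inspect
import pkgutil
from importlib import import_module


def build_registry(package_name: str, search_path: str, base_cls: type, key_attr: str) -> dict:
    # Import every module found under search_path and index subclasses of
    # base_cls by the value of key_attr, skipping classes with an empty key.
    registry = {}
    for _, module_name, _ in pkgutil.iter_modules([search_path]):
        module = import_module("." + module_name, package=package_name)
        for obj in vars(module).values():
            if (
                inspect.isclass(obj)
                and issubclass(obj, base_cls)
                and getattr(obj, key_attr, "") != ""
            ):
                registry[getattr(obj, key_attr)] = obj
    return registry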


@ -19,37 +19,36 @@ from superset.db_engine_specs.base import BaseEngineSpec
class AthenaEngineSpec(BaseEngineSpec): class AthenaEngineSpec(BaseEngineSpec):
engine = 'awsathena' engine = "awsathena"
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))", "PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))",
'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))", "PT1M": "date_trunc('minute', CAST({col} AS TIMESTAMP))",
'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))", "PT1H": "date_trunc('hour', CAST({col} AS TIMESTAMP))",
'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))", "P1D": "date_trunc('day', CAST({col} AS TIMESTAMP))",
'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))", "P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))",
'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))", "P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))",
'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))", "P0.25Y": "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))", "P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))",
'P1W/1970-01-03T00:00:00Z': "date_add('day', 5, date_trunc('week', \ "P1W/1970-01-03T00:00:00Z": "date_add('day', 5, date_trunc('week', \
date_add('day', 1, CAST({col} AS TIMESTAMP))))", date_add('day', 1, CAST({col} AS TIMESTAMP))))",
'1969-12-28T00:00:00Z/P1W': "date_add('day', -1, date_trunc('week', \ "1969-12-28T00:00:00Z/P1W": "date_add('day', -1, date_trunc('week', \
date_add('day', 1, CAST({col} AS TIMESTAMP))))", date_add('day', 1, CAST({col} AS TIMESTAMP))))",
} }
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "from_iso8601_date('{}')".format(dttm.isoformat()[:10]) return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
if tt == 'TIMESTAMP': if tt == "TIMESTAMP":
return "from_iso8601_timestamp('{}')".format(dttm.isoformat()) return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
return ("CAST ('{}' AS TIMESTAMP)" return "CAST ('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
.format(dttm.strftime('%Y-%m-%d %H:%M:%S')))
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return 'from_unixtime({col})' return "from_unixtime({col})"
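For reference, the Athena convert_dttm above is plain string formatting; the two handled branches produce literals like these (standalone reproduction of the same format() calls, not invoking Superset):

from datetime import datetime

dttm = datetime(2019, 1, 2, 3, 4, 5)
print("from_iso8601_date('{}')".format(dttm.isoformat()[:10]))
# from_iso8601_date('2019-01-02')
print("from_iso8601_timestamp('{}')".format(dttm.isoformat()))
# from_iso8601_timestamp('2019-01-02T03:04:05')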
@staticmethod @staticmethod
def mutate_label(label): def mutate_label(label):


@ -37,7 +37,7 @@ from werkzeug.utils import secure_filename
from superset import app, db, sql_parse from superset import app, db, sql_parse
from superset.utils import core as utils from superset.utils import core as utils
Grain = namedtuple('Grain', 'name label function duration') Grain = namedtuple("Grain", "name label function duration")
config = app.config config = app.config
@ -46,23 +46,23 @@ QueryStatus = utils.QueryStatus
config = app.config config = app.config
builtin_time_grains = { builtin_time_grains = {
None: 'Time Column', None: "Time Column",
'PT1S': 'second', "PT1S": "second",
'PT1M': 'minute', "PT1M": "minute",
'PT5M': '5 minute', "PT5M": "5 minute",
'PT10M': '10 minute', "PT10M": "10 minute",
'PT15M': '15 minute', "PT15M": "15 minute",
'PT0.5H': 'half hour', "PT0.5H": "half hour",
'PT1H': 'hour', "PT1H": "hour",
'P1D': 'day', "P1D": "day",
'P1W': 'week', "P1W": "week",
'P1M': 'month', "P1M": "month",
'P0.25Y': 'quarter', "P0.25Y": "quarter",
'P1Y': 'year', "P1Y": "year",
'1969-12-28T00:00:00Z/P1W': 'week_start_sunday', "1969-12-28T00:00:00Z/P1W": "week_start_sunday",
'1969-12-29T00:00:00Z/P1W': 'week_start_monday', "1969-12-29T00:00:00Z/P1W": "week_start_monday",
'P1W/1970-01-03T00:00:00Z': 'week_ending_saturday', "P1W/1970-01-03T00:00:00Z": "week_ending_saturday",
'P1W/1970-01-04T00:00:00Z': 'week_ending_sunday', "P1W/1970-01-04T00:00:00Z": "week_ending_sunday",
} }
@ -81,14 +81,15 @@ class TimestampExpression(ColumnClause):
@compiles(TimestampExpression) @compiles(TimestampExpression)
def compile_timegrain_expression(element: TimestampExpression, compiler, **kw): def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
return element.name.replace('{col}', compiler.process(element.col, **kw)) return element.name.replace("{col}", compiler.process(element.col, **kw))
class LimitMethod(object): class LimitMethod(object):
"""Enum the ways that limits can be applied""" """Enum the ways that limits can be applied"""
FETCH_MANY = 'fetch_many'
WRAP_SQL = 'wrap_sql' FETCH_MANY = "fetch_many"
FORCE_LIMIT = 'force_limit' WRAP_SQL = "wrap_sql"
FORCE_LIMIT = "force_limit"
def create_time_grains_tuple(time_grains, time_grain_functions, blacklist): def create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
@ -104,7 +105,7 @@ def create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
class BaseEngineSpec(object): class BaseEngineSpec(object):
"""Abstract class for database engine specific configurations""" """Abstract class for database engine specific configurations"""
engine = 'base' # str as defined in sqlalchemy.engine.engine engine = "base" # str as defined in sqlalchemy.engine.engine
time_grain_functions: Dict[Optional[str], str] = {} time_grain_functions: Dict[Optional[str], str] = {}
time_groupby_inline = False time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT limit_method = LimitMethod.FORCE_LIMIT
@ -118,8 +119,9 @@ class BaseEngineSpec(object):
try_remove_schema_from_table_name = True try_remove_schema_from_table_name = True
@classmethod @classmethod
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str], def get_timestamp_expr(
time_grain: Optional[str]) -> TimestampExpression: cls, col: ColumnClause, pdf: Optional[str], time_grain: Optional[str]
) -> TimestampExpression:
""" """
Construct a TimeExpression to be used in a SQLAlchemy query. Construct a TimeExpression to be used in a SQLAlchemy query.
@ -132,25 +134,26 @@ class BaseEngineSpec(object):
time_expr = cls.time_grain_functions.get(time_grain) time_expr = cls.time_grain_functions.get(time_grain)
if not time_expr: if not time_expr:
raise NotImplementedError( raise NotImplementedError(
f'No grain spec for {time_grain} for database {cls.engine}') f"No grain spec for {time_grain} for database {cls.engine}"
)
else: else:
time_expr = '{col}' time_expr = "{col}"
# if epoch, translate to DATE using db specific conf # if epoch, translate to DATE using db specific conf
if pdf == 'epoch_s': if pdf == "epoch_s":
time_expr = time_expr.replace('{col}', cls.epoch_to_dttm()) time_expr = time_expr.replace("{col}", cls.epoch_to_dttm())
elif pdf == 'epoch_ms': elif pdf == "epoch_ms":
time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm()) time_expr = time_expr.replace("{col}", cls.epoch_ms_to_dttm())
return TimestampExpression(time_expr, col, type_=DateTime) return TimestampExpression(time_expr, col, type_=DateTime)
@classmethod @classmethod
def get_time_grains(cls): def get_time_grains(cls):
blacklist = config.get('TIME_GRAIN_BLACKLIST', []) blacklist = config.get("TIME_GRAIN_BLACKLIST", [])
grains = builtin_time_grains.copy() grains = builtin_time_grains.copy()
grains.update(config.get('TIME_GRAIN_ADDONS', {})) grains.update(config.get("TIME_GRAIN_ADDONS", {}))
grain_functions = cls.time_grain_functions.copy() grain_functions = cls.time_grain_functions.copy()
grain_addon_functions = config.get('TIME_GRAIN_ADDON_FUNCTIONS', {}) grain_addon_functions = config.get("TIME_GRAIN_ADDON_FUNCTIONS", {})
grain_functions.update(grain_addon_functions.get(cls.engine, {})) grain_functions.update(grain_addon_functions.get(cls.engine, {}))
return create_time_grains_tuple(grains, grain_functions, blacklist) return create_time_grains_tuple(grains, grain_functions, blacklist)
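get_time_grains merges the built-in grains with two config hooks; judging from the lookups above, an addon pairs a new grain id with a label and, per engine, a SQL expression template. A hedged example of what such config values might look like (the grain id and the ClickHouse expression here are illustrative, not shipped defaults):

TIME_GRAIN_BLACKLIST = ["PT5M"]          # hide an unwanted grain
TIME_GRAIN_ADDONS = {"PT2H": "2 hour"}   # grain id -> label
TIME_GRAIN_ADDON_FUNCTIONS = {           # engine -> {grain id -> expression}
    "clickhouse": {
        "PT2H": "toDateTime(intDiv(toUInt32(toDateTime({col})), 7200)*7200)"
    }
}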
@ -169,9 +172,9 @@ class BaseEngineSpec(object):
return cursor.fetchall() return cursor.fetchall()
@classmethod @classmethod
def expand_data(cls, def expand_data(
columns: List[dict], cls, columns: List[dict], data: List[dict]
data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]: ) -> Tuple[List[dict], List[dict], List[dict]]:
return columns, data, [] return columns, data, []
@classmethod @classmethod
@ -189,7 +192,7 @@ class BaseEngineSpec(object):
@classmethod @classmethod
def epoch_ms_to_dttm(cls): def epoch_ms_to_dttm(cls):
return cls.epoch_to_dttm().replace('{col}', '({col}/1000)') return cls.epoch_to_dttm().replace("{col}", "({col}/1000)")
@classmethod @classmethod
def get_datatype(cls, type_code): def get_datatype(cls, type_code):
@ -205,12 +208,10 @@ class BaseEngineSpec(object):
def apply_limit_to_sql(cls, sql, limit, database): def apply_limit_to_sql(cls, sql, limit, database):
"""Alters the SQL statement to apply a LIMIT clause""" """Alters the SQL statement to apply a LIMIT clause"""
if cls.limit_method == LimitMethod.WRAP_SQL: if cls.limit_method == LimitMethod.WRAP_SQL:
sql = sql.strip('\t\n ;') sql = sql.strip("\t\n ;")
qry = ( qry = (
select('*') select("*")
.select_from( .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry"))
TextAsFrom(text(sql), ['*']).alias('inner_qry'),
)
.limit(limit) .limit(limit)
) )
return database.compile_sqla_query(qry) return database.compile_sqla_query(qry)
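The WRAP_SQL branch simply nests the statement in an outer SELECT * ... LIMIT; a standalone sketch using the same SQLAlchemy calls as above (exact rendering depends on the dialect in use):

from sqlalchemy import select, text
from sqlalchemy.sql.expression import TextAsFrom

sql = "SELECT a, b FROM some_table"
qry = (
    select("*")
    .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry"))
    .limit(100)
)
print(str(qry))
# roughly: SELECT * FROM (SELECT a, b FROM some_table) AS inner_qry LIMIT :param_1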
@ -235,10 +236,11 @@ class BaseEngineSpec(object):
:param kwargs: params to be passed to DataFrame.read_csv :param kwargs: params to be passed to DataFrame.read_csv
:return: Pandas DataFrame containing data from csv :return: Pandas DataFrame containing data from csv
""" """
kwargs['filepath_or_buffer'] = \ kwargs["filepath_or_buffer"] = (
config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer'] config["UPLOAD_FOLDER"] + kwargs["filepath_or_buffer"]
kwargs['encoding'] = 'utf-8' )
kwargs['iterator'] = True kwargs["encoding"] = "utf-8"
kwargs["iterator"] = True
chunks = pd.read_csv(**kwargs) chunks = pd.read_csv(**kwargs)
df = pd.concat(chunk for chunk in chunks) df = pd.concat(chunk for chunk in chunks)
return df return df
@ -260,39 +262,42 @@ class BaseEngineSpec(object):
:param form: Parameters defining how to process data :param form: Parameters defining how to process data
:param table: Metadata of new table to be created :param table: Metadata of new table to be created
""" """
def _allowed_file(filename: str) -> bool: def _allowed_file(filename: str) -> bool:
# Only allow specific file extensions as specified in the config # Only allow specific file extensions as specified in the config
extension = os.path.splitext(filename)[1] extension = os.path.splitext(filename)[1]
return extension is not None and extension[1:] in config['ALLOWED_EXTENSIONS'] return (
extension is not None and extension[1:] in config["ALLOWED_EXTENSIONS"]
)
filename = secure_filename(form.csv_file.data.filename) filename = secure_filename(form.csv_file.data.filename)
if not _allowed_file(filename): if not _allowed_file(filename):
raise Exception('Invalid file type selected') raise Exception("Invalid file type selected")
csv_to_df_kwargs = { csv_to_df_kwargs = {
'filepath_or_buffer': filename, "filepath_or_buffer": filename,
'sep': form.sep.data, "sep": form.sep.data,
'header': form.header.data if form.header.data else 0, "header": form.header.data if form.header.data else 0,
'index_col': form.index_col.data, "index_col": form.index_col.data,
'mangle_dupe_cols': form.mangle_dupe_cols.data, "mangle_dupe_cols": form.mangle_dupe_cols.data,
'skipinitialspace': form.skipinitialspace.data, "skipinitialspace": form.skipinitialspace.data,
'skiprows': form.skiprows.data, "skiprows": form.skiprows.data,
'nrows': form.nrows.data, "nrows": form.nrows.data,
'skip_blank_lines': form.skip_blank_lines.data, "skip_blank_lines": form.skip_blank_lines.data,
'parse_dates': form.parse_dates.data, "parse_dates": form.parse_dates.data,
'infer_datetime_format': form.infer_datetime_format.data, "infer_datetime_format": form.infer_datetime_format.data,
'chunksize': 10000, "chunksize": 10000,
} }
df = cls.csv_to_df(**csv_to_df_kwargs) df = cls.csv_to_df(**csv_to_df_kwargs)
df_to_sql_kwargs = { df_to_sql_kwargs = {
'df': df, "df": df,
'name': form.name.data, "name": form.name.data,
'con': create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False), "con": create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False),
'schema': form.schema.data, "schema": form.schema.data,
'if_exists': form.if_exists.data, "if_exists": form.if_exists.data,
'index': form.index.data, "index": form.index.data,
'index_label': form.index_label.data, "index_label": form.index_label.data,
'chunksize': 10000, "chunksize": 10000,
} }
cls.df_to_sql(**df_to_sql_kwargs) cls.df_to_sql(**df_to_sql_kwargs)
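End to end, the CSV upload path above is pandas plus SQLAlchemy: read the file in chunks, concatenate, then to_sql into the target database. A condensed, self-contained sketch of that flow (the in-memory CSV and SQLite engine are illustrative stand-ins):

import io

import pandas as pd
from sqlalchemy import create_engine

csv_file = io.StringIO("name,value\na,1\nb,2\n")
chunks = pd.read_csv(csv_file, chunksize=10000, iterator=True)
df = pd.concat(chunk for chunk in chunks)
df.to_sql("uploaded_table", con=create_engine("sqlite://"), if_exists="fail", index=False)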
@ -304,34 +309,41 @@ class BaseEngineSpec(object):
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod @classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \ def get_all_datasource_names(
-> List[utils.DatasourceName]: cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
"""Returns a list of all tables or views in database. """Returns a list of all tables or views in database.
:param db: Database instance :param db: Database instance
:param datasource_type: Datasource_type can be 'table' or 'view' :param datasource_type: Datasource_type can be 'table' or 'view'
:return: List of all datasources in database or schema :return: List of all datasources in database or schema
""" """
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled, schemas = db.get_all_schema_names(
cache_timeout=db.schema_cache_timeout, cache=db.schema_cache_enabled,
force=True) cache_timeout=db.schema_cache_timeout,
force=True,
)
all_datasources: List[utils.DatasourceName] = [] all_datasources: List[utils.DatasourceName] = []
for schema in schemas: for schema in schemas:
if datasource_type == 'table': if datasource_type == "table":
all_datasources += db.get_all_table_names_in_schema( all_datasources += db.get_all_table_names_in_schema(
schema=schema, force=True, schema=schema,
force=True,
cache=db.table_cache_enabled, cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout) cache_timeout=db.table_cache_timeout,
elif datasource_type == 'view': )
elif datasource_type == "view":
all_datasources += db.get_all_view_names_in_schema( all_datasources += db.get_all_view_names_in_schema(
schema=schema, force=True, schema=schema,
force=True,
cache=db.table_cache_enabled, cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout) cache_timeout=db.table_cache_timeout,
)
else: else:
raise Exception(f'Unsupported datasource_type: {datasource_type}') raise Exception(f"Unsupported datasource_type: {datasource_type}")
return all_datasources return all_datasources
@classmethod @classmethod
@ -381,14 +393,14 @@ class BaseEngineSpec(object):
def get_table_names(cls, inspector, schema): def get_table_names(cls, inspector, schema):
tables = inspector.get_table_names(schema) tables = inspector.get_table_names(schema)
if schema and cls.try_remove_schema_from_table_name: if schema and cls.try_remove_schema_from_table_name:
tables = [re.sub(f'^{schema}\\.', '', table) for table in tables] tables = [re.sub(f"^{schema}\\.", "", table) for table in tables]
return sorted(tables) return sorted(tables)
@classmethod @classmethod
def get_view_names(cls, inspector, schema): def get_view_names(cls, inspector, schema):
views = inspector.get_view_names(schema) views = inspector.get_view_names(schema)
if schema and cls.try_remove_schema_from_table_name: if schema and cls.try_remove_schema_from_table_name:
views = [re.sub(f'^{schema}\\.', '', view) for view in views] views = [re.sub(f"^{schema}\\.", "", view) for view in views]
return sorted(views) return sorted(views)
@classmethod @classmethod
@ -396,19 +408,27 @@ class BaseEngineSpec(object):
return inspector.get_columns(table_name, schema) return inspector.get_columns(table_name, schema)
@classmethod @classmethod
def where_latest_partition( def where_latest_partition(cls, table_name, schema, database, qry, columns=None):
cls, table_name, schema, database, qry, columns=None):
return False return False
@classmethod @classmethod
def _get_fields(cls, cols): def _get_fields(cls, cols):
return [column(c.get('name')) for c in cols] return [column(c.get("name")) for c in cols]
@classmethod @classmethod
def select_star(cls, my_db, table_name, engine, schema=None, limit=100, def select_star(
show_cols=False, indent=True, latest_partition=True, cls,
cols=None): my_db,
fields = '*' table_name,
engine,
schema=None,
limit=100,
show_cols=False,
indent=True,
latest_partition=True,
cols=None,
):
fields = "*"
cols = cols or [] cols = cols or []
if (show_cols or latest_partition) and not cols: if (show_cols or latest_partition) and not cols:
cols = my_db.get_columns(table_name, schema) cols = my_db.get_columns(table_name, schema)
@ -417,7 +437,7 @@ class BaseEngineSpec(object):
fields = cls._get_fields(cols) fields = cls._get_fields(cols)
quote = engine.dialect.identifier_preparer.quote quote = engine.dialect.identifier_preparer.quote
if schema: if schema:
full_table_name = quote(schema) + '.' + quote(table_name) full_table_name = quote(schema) + "." + quote(table_name)
else: else:
full_table_name = quote(table_name) full_table_name = quote(table_name)
@ -427,7 +447,8 @@ class BaseEngineSpec(object):
qry = qry.limit(limit) qry = qry.limit(limit)
if latest_partition: if latest_partition:
partition_query = cls.where_latest_partition( partition_query = cls.where_latest_partition(
table_name, schema, my_db, qry, columns=cols) table_name, schema, my_db, qry, columns=cols
)
if partition_query != False: # noqa if partition_query != False: # noqa
qry = partition_query qry = partition_query
sql = my_db.compile_sqla_query(qry) sql = my_db.compile_sqla_query(qry)
@ -475,7 +496,10 @@ class BaseEngineSpec(object):
generate a truncated label by calling truncate_label(). generate a truncated label by calling truncate_label().
""" """
label_mutated = cls.mutate_label(label) label_mutated = cls.mutate_label(label)
if cls.max_column_name_length and len(label_mutated) > cls.max_column_name_length: if (
cls.max_column_name_length
and len(label_mutated) > cls.max_column_name_length
):
label_mutated = cls.truncate_label(label) label_mutated = cls.truncate_label(label)
if cls.force_column_alias_quotes: if cls.force_column_alias_quotes:
label_mutated = quoted_name(label_mutated, True) label_mutated = quoted_name(label_mutated, True)
@ -510,10 +534,10 @@ class BaseEngineSpec(object):
this method is used to construct a deterministic and unique label based on this method is used to construct a deterministic and unique label based on
an md5 hash. an md5 hash.
""" """
label = hashlib.md5(label.encode('utf-8')).hexdigest() label = hashlib.md5(label.encode("utf-8")).hexdigest()
# truncate hash if it exceeds max length # truncate hash if it exceeds max length
if cls.max_column_name_length and len(label) > cls.max_column_name_length: if cls.max_column_name_length and len(label) > cls.max_column_name_length:
label = label[:cls.max_column_name_length] label = label[: cls.max_column_name_length]
return label return label
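In other words, an over-long label collapses to an md5 hex digest cut to the engine's limit; for example (standalone, with an illustrative 30-character cap):

import hashlib

max_column_name_length = 30
label = hashlib.md5("SUM(some_very_long_metric_name)".encode("utf-8")).hexdigest()
print(label[:max_column_name_length], len(label[:max_column_name_length]))
# prints the first 30 hex characters of the digest, and 30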
@classmethod @classmethod


@ -27,7 +27,8 @@ class BigQueryEngineSpec(BaseEngineSpec):
"""Engine spec for Google's BigQuery """Engine spec for Google's BigQuery
As contributed by @mxmzdlv on issue #945""" As contributed by @mxmzdlv on issue #945"""
engine = 'bigquery'
engine = "bigquery"
max_column_name_length = 128 max_column_name_length = 128
""" """
@ -43,28 +44,28 @@ class BigQueryEngineSpec(BaseEngineSpec):
arraysize = 5000 arraysize = 5000
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': 'TIMESTAMP_TRUNC({col}, SECOND)', "PT1S": "TIMESTAMP_TRUNC({col}, SECOND)",
'PT1M': 'TIMESTAMP_TRUNC({col}, MINUTE)', "PT1M": "TIMESTAMP_TRUNC({col}, MINUTE)",
'PT1H': 'TIMESTAMP_TRUNC({col}, HOUR)', "PT1H": "TIMESTAMP_TRUNC({col}, HOUR)",
'P1D': 'TIMESTAMP_TRUNC({col}, DAY)', "P1D": "TIMESTAMP_TRUNC({col}, DAY)",
'P1W': 'TIMESTAMP_TRUNC({col}, WEEK)', "P1W": "TIMESTAMP_TRUNC({col}, WEEK)",
'P1M': 'TIMESTAMP_TRUNC({col}, MONTH)', "P1M": "TIMESTAMP_TRUNC({col}, MONTH)",
'P0.25Y': 'TIMESTAMP_TRUNC({col}, QUARTER)', "P0.25Y": "TIMESTAMP_TRUNC({col}, QUARTER)",
'P1Y': 'TIMESTAMP_TRUNC({col}, YEAR)', "P1Y": "TIMESTAMP_TRUNC({col}, YEAR)",
} }
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "'{}'".format(dttm.strftime('%Y-%m-%d')) return "'{}'".format(dttm.strftime("%Y-%m-%d"))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod @classmethod
def fetch_data(cls, cursor, limit): def fetch_data(cls, cursor, limit):
data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit) data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit)
if data and type(data[0]).__name__ == 'Row': if data and type(data[0]).__name__ == "Row":
data = [r.values() for r in data] data = [r.values() for r in data]
return data return data
@ -78,13 +79,13 @@ class BigQueryEngineSpec(BaseEngineSpec):
:param str label: the original label which might include unsupported characters :param str label: the original label which might include unsupported characters
:return: String that is supported by the database :return: String that is supported by the database
""" """
label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest() label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
# if label starts with number, add underscore as first character # if label starts with number, add underscore as first character
label_mutated = '_' + label if re.match(r'^\d', label) else label label_mutated = "_" + label if re.match(r"^\d", label) else label
# replace non-alphanumeric characters with underscores # replace non-alphanumeric characters with underscores
label_mutated = re.sub(r'[^\w]+', '_', label_mutated) label_mutated = re.sub(r"[^\w]+", "_", label_mutated)
if label_mutated != label: if label_mutated != label:
# add first 5 chars from md5 hash to label to avoid possible collisions # add first 5 chars from md5 hash to label to avoid possible collisions
label_mutated += label_hashed[:6] label_mutated += label_hashed[:6]
@ -97,7 +98,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
underscore. To make sure this is always the case, an underscore is prefixed underscore. To make sure this is always the case, an underscore is prefixed
to the truncated label. to the truncated label.
""" """
return '_' + hashlib.md5(label.encode('utf-8')).hexdigest() return "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
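Taken together, the two BigQuery label helpers above replace anything non-alphanumeric with underscores, prefix labels that start with a digit, and append a short md5 suffix when the label had to change; a quick standalone illustration of those mutation rules:

import hashlib
import re

label = "3rd party/clicks"
label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
label_mutated = "_" + label if re.match(r"^\d", label) else label
label_mutated = re.sub(r"[^\w]+", "_", label_mutated)
if label_mutated != label:
    label_mutated += label_hashed[:6]  # "_" plus the first 5 hash characters
print(label_mutated)
# "_3rd_party_clicks" followed by "_" and 5 hex characters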
@classmethod @classmethod
def extra_table_metadata(cls, database, table_name, schema_name): def extra_table_metadata(cls, database, table_name, schema_name):
@ -105,20 +106,18 @@ class BigQueryEngineSpec(BaseEngineSpec):
if not indexes: if not indexes:
return {} return {}
partitions_columns = [ partitions_columns = [
index.get('column_names', []) for index in indexes index.get("column_names", [])
if index.get('name') == 'partition' for index in indexes
if index.get("name") == "partition"
] ]
cluster_columns = [ cluster_columns = [
index.get('column_names', []) for index in indexes index.get("column_names", [])
if index.get('name') == 'clustering' for index in indexes
if index.get("name") == "clustering"
] ]
return { return {
'partitions': { "partitions": {"cols": partitions_columns},
'cols': partitions_columns, "clustering": {"cols": cluster_columns},
},
'clustering': {
'cols': cluster_columns,
},
} }
@classmethod @classmethod
@ -131,16 +130,18 @@ class BigQueryEngineSpec(BaseEngineSpec):
Also explicitly specifying column names so we don't encounter duplicate Also explicitly specifying column names so we don't encounter duplicate
column names in the result. column names in the result.
""" """
return [literal_column(c.get('name')).label(c.get('name').replace('.', '__')) return [
for c in cols] literal_column(c.get("name")).label(c.get("name").replace(".", "__"))
for c in cols
]
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return 'TIMESTAMP_SECONDS({col})' return "TIMESTAMP_SECONDS({col})"
@classmethod @classmethod
def epoch_ms_to_dttm(cls): def epoch_ms_to_dttm(cls):
return 'TIMESTAMP_MILLIS({col})' return "TIMESTAMP_MILLIS({col})"
@classmethod @classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs): def df_to_sql(cls, df: pd.DataFrame, **kwargs):
@ -156,18 +157,20 @@ class BigQueryEngineSpec(BaseEngineSpec):
try: try:
import pandas_gbq import pandas_gbq
except ImportError: except ImportError:
raise Exception('Could not import the library `pandas_gbq`, which is ' raise Exception(
'required to be installed in your environment in order ' "Could not import the library `pandas_gbq`, which is "
'to upload data to BigQuery') "required to be installed in your environment in order "
"to upload data to BigQuery"
)
if not ('name' in kwargs and 'schema' in kwargs): if not ("name" in kwargs and "schema" in kwargs):
raise Exception('name and schema need to be defined in kwargs') raise Exception("name and schema need to be defined in kwargs")
gbq_kwargs = {} gbq_kwargs = {}
gbq_kwargs['project_id'] = kwargs['con'].engine.url.host gbq_kwargs["project_id"] = kwargs["con"].engine.url.host
gbq_kwargs['destination_table'] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}" gbq_kwargs["destination_table"] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}"
# Only pass through supported kwargs # Only pass through supported kwargs
supported_kwarg_keys = {'if_exists'} supported_kwarg_keys = {"if_exists"}
for key in supported_kwarg_keys: for key in supported_kwarg_keys:
if key in kwargs: if key in kwargs:
gbq_kwargs[key] = kwargs[key] gbq_kwargs[key] = kwargs[key]


@ -21,32 +21,31 @@ from superset.db_engine_specs.base import BaseEngineSpec
class ClickHouseEngineSpec(BaseEngineSpec): class ClickHouseEngineSpec(BaseEngineSpec):
"""Dialect for ClickHouse analytical DB.""" """Dialect for ClickHouse analytical DB."""
engine = 'clickhouse' engine = "clickhouse"
time_secondary_columns = True time_secondary_columns = True
time_groupby_inline = True time_groupby_inline = True
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1M': 'toStartOfMinute(toDateTime({col}))', "PT1M": "toStartOfMinute(toDateTime({col}))",
'PT5M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)', "PT5M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)",
'PT10M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)', "PT10M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)",
'PT15M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)', "PT15M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)",
'PT0.5H': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)', "PT0.5H": "toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)",
'PT1H': 'toStartOfHour(toDateTime({col}))', "PT1H": "toStartOfHour(toDateTime({col}))",
'P1D': 'toStartOfDay(toDateTime({col}))', "P1D": "toStartOfDay(toDateTime({col}))",
'P1W': 'toMonday(toDateTime({col}))', "P1W": "toMonday(toDateTime({col}))",
'P1M': 'toStartOfMonth(toDateTime({col}))', "P1M": "toStartOfMonth(toDateTime({col}))",
'P0.25Y': 'toStartOfQuarter(toDateTime({col}))', "P0.25Y": "toStartOfQuarter(toDateTime({col}))",
'P1Y': 'toStartOfYear(toDateTime({col}))', "P1Y": "toStartOfYear(toDateTime({col}))",
} }
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "toDate('{}')".format(dttm.strftime('%Y-%m-%d')) return "toDate('{}')".format(dttm.strftime("%Y-%m-%d"))
if tt == 'DATETIME': if tt == "DATETIME":
return "toDateTime('{}')".format( return "toDateTime('{}')".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))


@ -19,34 +19,32 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class Db2EngineSpec(BaseEngineSpec): class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa' engine = "ibm_db_sa"
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True force_column_alias_quotes = True
max_column_name_length = 30 max_column_name_length = 30
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': 'CAST({col} as TIMESTAMP)' "PT1S": "CAST({col} as TIMESTAMP)" " - MICROSECOND({col}) MICROSECONDS",
' - MICROSECOND({col}) MICROSECONDS', "PT1M": "CAST({col} as TIMESTAMP)"
'PT1M': 'CAST({col} as TIMESTAMP)' " - SECOND({col}) SECONDS"
' - SECOND({col}) SECONDS' " - MICROSECOND({col}) MICROSECONDS",
' - MICROSECOND({col}) MICROSECONDS', "PT1H": "CAST({col} as TIMESTAMP)"
'PT1H': 'CAST({col} as TIMESTAMP)' " - MINUTE({col}) MINUTES"
' - MINUTE({col}) MINUTES' " - SECOND({col}) SECONDS"
' - SECOND({col}) SECONDS' " - MICROSECOND({col}) MICROSECONDS ",
' - MICROSECOND({col}) MICROSECONDS ', "P1D": "CAST({col} as TIMESTAMP)"
'P1D': 'CAST({col} as TIMESTAMP)' " - HOUR({col}) HOURS"
' - HOUR({col}) HOURS' " - MINUTE({col}) MINUTES"
' - MINUTE({col}) MINUTES' " - SECOND({col}) SECONDS"
' - SECOND({col}) SECONDS' " - MICROSECOND({col}) MICROSECONDS",
' - MICROSECOND({col}) MICROSECONDS', "P1W": "{col} - (DAYOFWEEK({col})) DAYS",
'P1W': '{col} - (DAYOFWEEK({col})) DAYS', "P1M": "{col} - (DAY({col})-1) DAYS",
'P1M': '{col} - (DAY({col})-1) DAYS', "P0.25Y": "{col} - (DAY({col})-1) DAYS"
'P0.25Y': '{col} - (DAY({col})-1) DAYS' " - (MONTH({col})-1) MONTHS"
' - (MONTH({col})-1) MONTHS' " + ((QUARTER({col})-1) * 3) MONTHS",
' + ((QUARTER({col})-1) * 3) MONTHS', "P1Y": "{col} - (DAY({col})-1) DAYS" " - (MONTH({col})-1) MONTHS",
'P1Y': '{col} - (DAY({col})-1) DAYS'
' - (MONTH({col})-1) MONTHS',
} }
@classmethod @classmethod
@ -55,4 +53,4 @@ class Db2EngineSpec(BaseEngineSpec):
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d-%H.%M.%S"))


@ -22,43 +22,43 @@ from superset.db_engine_specs.base import BaseEngineSpec
class DrillEngineSpec(BaseEngineSpec): class DrillEngineSpec(BaseEngineSpec):
"""Engine spec for Apache Drill""" """Engine spec for Apache Drill"""
engine = 'drill'
engine = "drill"
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': "NEARESTDATE({col}, 'SECOND')", "PT1S": "NEARESTDATE({col}, 'SECOND')",
'PT1M': "NEARESTDATE({col}, 'MINUTE')", "PT1M": "NEARESTDATE({col}, 'MINUTE')",
'PT15M': "NEARESTDATE({col}, 'QUARTER_HOUR')", "PT15M": "NEARESTDATE({col}, 'QUARTER_HOUR')",
'PT0.5H': "NEARESTDATE({col}, 'HALF_HOUR')", "PT0.5H": "NEARESTDATE({col}, 'HALF_HOUR')",
'PT1H': "NEARESTDATE({col}, 'HOUR')", "PT1H": "NEARESTDATE({col}, 'HOUR')",
'P1D': "NEARESTDATE({col}, 'DAY')", "P1D": "NEARESTDATE({col}, 'DAY')",
'P1W': "NEARESTDATE({col}, 'WEEK_SUNDAY')", "P1W": "NEARESTDATE({col}, 'WEEK_SUNDAY')",
'P1M': "NEARESTDATE({col}, 'MONTH')", "P1M": "NEARESTDATE({col}, 'MONTH')",
'P0.25Y': "NEARESTDATE({col}, 'QUARTER')", "P0.25Y": "NEARESTDATE({col}, 'QUARTER')",
'P1Y': "NEARESTDATE({col}, 'YEAR')", "P1Y": "NEARESTDATE({col}, 'YEAR')",
} }
# Returns a function to convert a Unix timestamp in milliseconds to a date # Returns a function to convert a Unix timestamp in milliseconds to a date
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return cls.epoch_ms_to_dttm().replace('{col}', '({col}*1000)') return cls.epoch_ms_to_dttm().replace("{col}", "({col}*1000)")
@classmethod @classmethod
def epoch_ms_to_dttm(cls): def epoch_ms_to_dttm(cls):
return 'TO_DATE({col})' return "TO_DATE({col})"
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10]) return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
elif tt == 'TIMESTAMP': elif tt == "TIMESTAMP":
return "CAST('{}' AS TIMESTAMP)".format( return "CAST('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
@classmethod @classmethod
def adjust_database_uri(cls, uri, selected_schema): def adjust_database_uri(cls, uri, selected_schema):
if selected_schema: if selected_schema:
uri.database = parse.quote(selected_schema, safe='') uri.database = parse.quote(selected_schema, safe="")
return uri return uri


@ -20,23 +20,24 @@ from superset.db_engine_specs.base import BaseEngineSpec
class DruidEngineSpec(BaseEngineSpec): class DruidEngineSpec(BaseEngineSpec):
"""Engine spec for Druid.io""" """Engine spec for Druid.io"""
engine = 'druid'
engine = "druid"
inner_joins = False inner_joins = False
allows_subquery = False allows_subquery = False
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': 'FLOOR({col} TO SECOND)', "PT1S": "FLOOR({col} TO SECOND)",
'PT1M': 'FLOOR({col} TO MINUTE)', "PT1M": "FLOOR({col} TO MINUTE)",
'PT1H': 'FLOOR({col} TO HOUR)', "PT1H": "FLOOR({col} TO HOUR)",
'P1D': 'FLOOR({col} TO DAY)', "P1D": "FLOOR({col} TO DAY)",
'P1W': 'FLOOR({col} TO WEEK)', "P1W": "FLOOR({col} TO WEEK)",
'P1M': 'FLOOR({col} TO MONTH)', "P1M": "FLOOR({col} TO MONTH)",
'P0.25Y': 'FLOOR({col} TO QUARTER)', "P0.25Y": "FLOOR({col} TO QUARTER)",
'P1Y': 'FLOOR({col} TO YEAR)', "P1Y": "FLOOR({col} TO YEAR)",
} }
@classmethod @classmethod
def alter_new_orm_column(cls, orm_col): def alter_new_orm_column(cls, orm_col):
if orm_col.column_name == '__time': if orm_col.column_name == "__time":
orm_col.is_dttm = True orm_col.is_dttm = True


@ -20,6 +20,7 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
class GSheetsEngineSpec(SqliteEngineSpec): class GSheetsEngineSpec(SqliteEngineSpec):
"""Engine for Google spreadsheets""" """Engine for Google spreadsheets"""
engine = 'gsheets'
engine = "gsheets"
inner_joins = False inner_joins = False
allows_subquery = False allows_subquery = False


@ -38,30 +38,30 @@ from superset.utils import core as utils
QueryStatus = utils.QueryStatus QueryStatus = utils.QueryStatus
config = app.config config = app.config
tracking_url_trans = conf.get('TRACKING_URL_TRANSFORMER') tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
hive_poll_interval = conf.get('HIVE_POLL_INTERVAL') hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
class HiveEngineSpec(PrestoEngineSpec): class HiveEngineSpec(PrestoEngineSpec):
"""Reuses PrestoEngineSpec functionality.""" """Reuses PrestoEngineSpec functionality."""
engine = 'hive' engine = "hive"
max_column_name_length = 767 max_column_name_length = 767
# Scoping regex at class level to avoid recompiling # Scoping regex at class level to avoid recompiling
# 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5 # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
jobs_stats_r = re.compile( jobs_stats_r = re.compile(r".*INFO.*Total jobs = (?P<max_jobs>[0-9]+)")
r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
# 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5 # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
launching_job_r = re.compile( launching_job_r = re.compile(
'.*INFO.*Launching Job (?P<job_number>[0-9]+) out of ' ".*INFO.*Launching Job (?P<job_number>[0-9]+) out of " "(?P<max_jobs>[0-9]+)"
'(?P<max_jobs>[0-9]+)') )
# 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18 # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
# map = 0%, reduce = 0% # map = 0%, reduce = 0%
stage_progress_r = re.compile( stage_progress_r = re.compile(
r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*' r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
r'map = (?P<map_progress>[0-9]+)%.*' r"map = (?P<map_progress>[0-9]+)%.*"
r'reduce = (?P<reduce_progress>[0-9]+)%.*') r"reduce = (?P<reduce_progress>[0-9]+)%.*"
)
@classmethod @classmethod
def patch(cls): def patch(cls):
@ -70,7 +70,8 @@ class HiveEngineSpec(PrestoEngineSpec):
from TCLIService import ( from TCLIService import (
constants as patched_constants, constants as patched_constants,
ttypes as patched_ttypes, ttypes as patched_ttypes,
TCLIService as patched_TCLIService) TCLIService as patched_TCLIService,
)
hive.TCLIService = patched_TCLIService hive.TCLIService = patched_TCLIService
hive.constants = patched_constants hive.constants = patched_constants
@ -78,17 +79,19 @@ class HiveEngineSpec(PrestoEngineSpec):
hive.Cursor.fetch_logs = patched_hive.fetch_logs hive.Cursor.fetch_logs = patched_hive.fetch_logs
@classmethod @classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \ def get_all_datasource_names(
-> List[utils.DatasourceName]: cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
return BaseEngineSpec.get_all_datasource_names(db, datasource_type) return BaseEngineSpec.get_all_datasource_names(db, datasource_type)
@classmethod @classmethod
def fetch_data(cls, cursor, limit): def fetch_data(cls, cursor, limit):
import pyhive import pyhive
from TCLIService import ttypes from TCLIService import ttypes
state = cursor.poll() state = cursor.poll()
if state.operationState == ttypes.TOperationState.ERROR_STATE: if state.operationState == ttypes.TOperationState.ERROR_STATE:
raise Exception('Query error', state.errorMessage) raise Exception("Query error", state.errorMessage)
try: try:
return super(HiveEngineSpec, cls).fetch_data(cursor, limit) return super(HiveEngineSpec, cls).fetch_data(cursor, limit)
except pyhive.exc.ProgrammingError: except pyhive.exc.ProgrammingError:
@ -97,68 +100,76 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod @classmethod
def create_table_from_csv(cls, form, table): def create_table_from_csv(cls, form, table):
"""Uploads a csv file and creates a superset datasource in Hive.""" """Uploads a csv file and creates a superset datasource in Hive."""
def convert_to_hive_type(col_type): def convert_to_hive_type(col_type):
"""maps tableschema's types to hive types""" """maps tableschema's types to hive types"""
tableschema_to_hive_types = { tableschema_to_hive_types = {
'boolean': 'BOOLEAN', "boolean": "BOOLEAN",
'integer': 'INT', "integer": "INT",
'number': 'DOUBLE', "number": "DOUBLE",
'string': 'STRING', "string": "STRING",
} }
return tableschema_to_hive_types.get(col_type, 'STRING') return tableschema_to_hive_types.get(col_type, "STRING")
bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET'] bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
if not bucket_path: if not bucket_path:
logging.info('No upload bucket specified') logging.info("No upload bucket specified")
raise Exception( raise Exception(
'No upload bucket specified. You can specify one in the config file.') "No upload bucket specified. You can specify one in the config file."
)
table_name = form.name.data table_name = form.name.data
schema_name = form.schema.data schema_name = form.schema.data
if config.get('UPLOADED_CSV_HIVE_NAMESPACE'): if config.get("UPLOADED_CSV_HIVE_NAMESPACE"):
if '.' in table_name or schema_name: if "." in table_name or schema_name:
raise Exception( raise Exception(
"You can't specify a namespace. " "You can't specify a namespace. "
'All tables will be uploaded to the `{}` namespace'.format( "All tables will be uploaded to the `{}` namespace".format(
config.get('HIVE_NAMESPACE'))) config.get("HIVE_NAMESPACE")
full_table_name = '{}.{}'.format( )
config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name) )
full_table_name = "{}.{}".format(
config.get("UPLOADED_CSV_HIVE_NAMESPACE"), table_name
)
else: else:
if '.' in table_name and schema_name: if "." in table_name and schema_name:
raise Exception( raise Exception(
"You can't specify a namespace both in the name of the table " "You can't specify a namespace both in the name of the table "
'and in the schema field. Please remove one') "and in the schema field. Please remove one"
)
full_table_name = '{}.{}'.format( full_table_name = (
schema_name, table_name) if schema_name else table_name "{}.{}".format(schema_name, table_name) if schema_name else table_name
)
filename = form.csv_file.data.filename filename = form.csv_file.data.filename
upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY'] upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"]
upload_path = config['UPLOAD_FOLDER'] + \ upload_path = config["UPLOAD_FOLDER"] + secure_filename(filename)
secure_filename(filename)
# Optional dependency # Optional dependency
from tableschema import Table # pylint: disable=import-error from tableschema import Table # pylint: disable=import-error
hive_table_schema = Table(upload_path).infer() hive_table_schema = Table(upload_path).infer()
column_name_and_type = [] column_name_and_type = []
for column_info in hive_table_schema['fields']: for column_info in hive_table_schema["fields"]:
column_name_and_type.append( column_name_and_type.append(
'`{}` {}'.format( "`{}` {}".format(
column_info['name'], column_info["name"], convert_to_hive_type(column_info["type"])
convert_to_hive_type(column_info['type']))) )
schema_definition = ', '.join(column_name_and_type) )
schema_definition = ", ".join(column_name_and_type)
# Optional dependency # Optional dependency
import boto3 # pylint: disable=import-error import boto3 # pylint: disable=import-error
s3 = boto3.client('s3') s3 = boto3.client("s3")
location = os.path.join('s3a://', bucket_path, upload_prefix, table_name) location = os.path.join("s3a://", bucket_path, upload_prefix, table_name)
s3.upload_file( s3.upload_file(
upload_path, bucket_path, upload_path, bucket_path, os.path.join(upload_prefix, table_name, filename)
os.path.join(upload_prefix, table_name, filename)) )
sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} ) sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
TEXTFILE LOCATION '{location}' TEXTFILE LOCATION '{location}'
@ -170,17 +181,16 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10]) return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
elif tt == 'TIMESTAMP': elif tt == "TIMESTAMP":
return "CAST('{}' AS TIMESTAMP)".format( return "CAST('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
@classmethod @classmethod
def adjust_database_uri(cls, uri, selected_schema=None): def adjust_database_uri(cls, uri, selected_schema=None):
if selected_schema: if selected_schema:
uri.database = parse.quote(selected_schema, safe='') uri.database = parse.quote(selected_schema, safe="")
return uri return uri
@classmethod @classmethod
@ -199,34 +209,32 @@ class HiveEngineSpec(PrestoEngineSpec):
for line in log_lines: for line in log_lines:
match = cls.jobs_stats_r.match(line) match = cls.jobs_stats_r.match(line)
if match: if match:
total_jobs = int(match.groupdict()['max_jobs']) or 1 total_jobs = int(match.groupdict()["max_jobs"]) or 1
match = cls.launching_job_r.match(line) match = cls.launching_job_r.match(line)
if match: if match:
current_job = int(match.groupdict()['job_number']) current_job = int(match.groupdict()["job_number"])
total_jobs = int(match.groupdict()['max_jobs']) or 1 total_jobs = int(match.groupdict()["max_jobs"]) or 1
stages = {} stages = {}
match = cls.stage_progress_r.match(line) match = cls.stage_progress_r.match(line)
if match: if match:
stage_number = int(match.groupdict()['stage_number']) stage_number = int(match.groupdict()["stage_number"])
map_progress = int(match.groupdict()['map_progress']) map_progress = int(match.groupdict()["map_progress"])
reduce_progress = int(match.groupdict()['reduce_progress']) reduce_progress = int(match.groupdict()["reduce_progress"])
stages[stage_number] = (map_progress + reduce_progress) / 2 stages[stage_number] = (map_progress + reduce_progress) / 2
logging.info( logging.info(
'Progress detail: {}, ' "Progress detail: {}, "
'current job {}, ' "current job {}, "
'total jobs: {}'.format(stages, current_job, total_jobs)) "total jobs: {}".format(stages, current_job, total_jobs)
stage_progress = sum(
stages.values()) / len(stages.values()) if stages else 0
progress = (
100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
) )
stage_progress = sum(stages.values()) / len(stages.values()) if stages else 0
progress = 100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
return int(progress) return int(progress)
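The regexes above can be exercised directly against the sample log line quoted in the comments; a small standalone check (log text taken from those comments):

import re

stage_progress_r = re.compile(
    r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
    r"map = (?P<map_progress>[0-9]+)%.*"
    r"reduce = (?P<reduce_progress>[0-9]+)%.*"
)
line = (
    "17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18 "
    "map = 0%,  reduce = 0%"
)
print(stage_progress_r.match(line).groupdict())
# {'stage_number': '18', 'map_progress': '0', 'reduce_progress': '0'}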
@classmethod @classmethod
def get_tracking_url(cls, log_lines): def get_tracking_url(cls, log_lines):
lkp = 'Tracking URL = ' lkp = "Tracking URL = "
for line in log_lines: for line in log_lines:
if lkp in line: if lkp in line:
return line.split(lkp)[1] return line.split(lkp)[1]
@ -235,6 +243,7 @@ class HiveEngineSpec(PrestoEngineSpec):
def handle_cursor(cls, cursor, query, session): def handle_cursor(cls, cursor, query, session):
"""Updates progress information""" """Updates progress information"""
from pyhive import hive # pylint: disable=no-name-in-module from pyhive import hive # pylint: disable=no-name-in-module
unfinished_states = ( unfinished_states = (
hive.ttypes.TOperationState.INITIALIZED_STATE, hive.ttypes.TOperationState.INITIALIZED_STATE,
hive.ttypes.TOperationState.RUNNING_STATE, hive.ttypes.TOperationState.RUNNING_STATE,
@ -249,11 +258,11 @@ class HiveEngineSpec(PrestoEngineSpec):
cursor.cancel() cursor.cancel()
break break
log = cursor.fetch_logs() or '' log = cursor.fetch_logs() or ""
if log: if log:
log_lines = log.splitlines() log_lines = log.splitlines()
progress = cls.progress(log_lines) progress = cls.progress(log_lines)
logging.info('Progress total: {}'.format(progress)) logging.info("Progress total: {}".format(progress))
needs_commit = False needs_commit = False
if progress > query.progress: if progress > query.progress:
query.progress = progress query.progress = progress
@ -261,21 +270,19 @@ class HiveEngineSpec(PrestoEngineSpec):
if not tracking_url: if not tracking_url:
tracking_url = cls.get_tracking_url(log_lines) tracking_url = cls.get_tracking_url(log_lines)
if tracking_url: if tracking_url:
job_id = tracking_url.split('/')[-2] job_id = tracking_url.split("/")[-2]
logging.info( logging.info("Found the tracking url: {}".format(tracking_url))
'Found the tracking url: {}'.format(tracking_url))
tracking_url = tracking_url_trans(tracking_url) tracking_url = tracking_url_trans(tracking_url)
logging.info( logging.info("Transformation applied: {}".format(tracking_url))
'Transformation applied: {}'.format(tracking_url))
query.tracking_url = tracking_url query.tracking_url = tracking_url
logging.info('Job id: {}'.format(job_id)) logging.info("Job id: {}".format(job_id))
needs_commit = True needs_commit = True
if job_id and len(log_lines) > last_log_line: if job_id and len(log_lines) > last_log_line:
# Wait for job id before logging things out # Wait for job id before logging things out
# this allows for prefixing all log lines and becoming # this allows for prefixing all log lines and becoming
# searchable in something like Kibana # searchable in something like Kibana
for l in log_lines[last_log_line:]: for l in log_lines[last_log_line:]:
logging.info('[{}] {}'.format(job_id, l)) logging.info("[{}] {}".format(job_id, l))
last_log_line = len(log_lines) last_log_line = len(log_lines)
if needs_commit: if needs_commit:
session.commit() session.commit()
@ -284,21 +291,22 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod @classmethod
def get_columns( def get_columns(
cls, inspector: Inspector, table_name: str, schema: str) -> List[dict]: cls, inspector: Inspector, table_name: str, schema: str
) -> List[dict]:
return inspector.get_columns(table_name, schema) return inspector.get_columns(table_name, schema)
@classmethod @classmethod
def where_latest_partition( def where_latest_partition(cls, table_name, schema, database, qry, columns=None):
cls, table_name, schema, database, qry, columns=None):
try: try:
col_name, value = cls.latest_partition( col_name, value = cls.latest_partition(
table_name, schema, database, show_first=True) table_name, schema, database, show_first=True
)
except Exception: except Exception:
# table is not partitioned # table is not partitioned
return False return False
if value is not None: if value is not None:
for c in columns: for c in columns:
if c.get('name') == col_name: if c.get("name") == col_name:
return qry.where(Column(col_name) == value) return qry.where(Column(col_name) == value)
return False return False
@ -315,20 +323,36 @@ class HiveEngineSpec(PrestoEngineSpec):
def _latest_partition_from_df(cls, df): def _latest_partition_from_df(cls, df):
"""Hive partitions look like ds={partition name}""" """Hive partitions look like ds={partition name}"""
if not df.empty: if not df.empty:
return df.ix[:, 0].max().split('=')[1] return df.ix[:, 0].max().split("=")[1]
@classmethod @classmethod
def _partition_query( def _partition_query(cls, table_name, limit=0, order_by=None, filters=None):
cls, table_name, limit=0, order_by=None, filters=None): return f"SHOW PARTITIONS {table_name}"
return f'SHOW PARTITIONS {table_name}'
@classmethod @classmethod
def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None, def select_star(
limit: int = 100, show_cols: bool = False, indent: bool = True, cls,
latest_partition: bool = True, cols: List[dict] = []) -> str: my_db,
table_name: str,
engine: Engine,
schema: str = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: List[dict] = [],
) -> str:
return BaseEngineSpec.select_star( return BaseEngineSpec.select_star(
my_db, table_name, engine, schema, limit, my_db,
show_cols, indent, latest_partition, cols) table_name,
engine,
schema,
limit,
show_cols,
indent,
latest_partition,
cols,
)
@classmethod @classmethod
def modify_url_for_impersonation(cls, url, impersonate_user, username): def modify_url_for_impersonation(cls, url, impersonate_user, username):
@ -356,13 +380,18 @@ class HiveEngineSpec(PrestoEngineSpec):
url = make_url(uri) url = make_url(uri)
backend_name = url.get_backend_name() backend_name = url.get_backend_name()
# Must be Hive connection, enable impersonation, and set param auth=LDAP|KERBEROS # Must be Hive connection, enable impersonation, and set param
if (backend_name == 'hive' and 'auth' in url.query.keys() and # auth=LDAP|KERBEROS
impersonate_user is True and username is not None): if (
configuration['hive.server2.proxy.user'] = username backend_name == "hive"
and "auth" in url.query.keys()
and impersonate_user is True
and username is not None
):
configuration["hive.server2.proxy.user"] = username
return configuration return configuration
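The impersonation gate above only inspects the SQLAlchemy URL; for a Kerberized Hive URI (host and port here are illustrative) the relevant checks evaluate as follows:

from sqlalchemy.engine.url import make_url

url = make_url("hive://hiveserver2:10000/default?auth=KERBEROS")
print(url.get_backend_name())      # hive
print("auth" in url.query.keys())  # True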
@staticmethod @staticmethod
def execute(cursor, query, async_=False): def execute(cursor, query, async_=False):
kwargs = {'async': async_} kwargs = {"async": async_}
cursor.execute(query, **kwargs) cursor.execute(query, **kwargs)


@ -21,32 +21,35 @@ from superset.db_engine_specs.base import BaseEngineSpec
class ImpalaEngineSpec(BaseEngineSpec): class ImpalaEngineSpec(BaseEngineSpec):
"""Engine spec for Cloudera's Impala""" """Engine spec for Cloudera's Impala"""
engine = 'impala' engine = "impala"
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1M': "TRUNC({col}, 'MI')", "PT1M": "TRUNC({col}, 'MI')",
'PT1H': "TRUNC({col}, 'HH')", "PT1H": "TRUNC({col}, 'HH')",
'P1D': "TRUNC({col}, 'DD')", "P1D": "TRUNC({col}, 'DD')",
'P1W': "TRUNC({col}, 'WW')", "P1W": "TRUNC({col}, 'WW')",
'P1M': "TRUNC({col}, 'MONTH')", "P1M": "TRUNC({col}, 'MONTH')",
'P0.25Y': "TRUNC({col}, 'Q')", "P0.25Y": "TRUNC({col}, 'Q')",
'P1Y': "TRUNC({col}, 'YYYY')", "P1Y": "TRUNC({col}, 'YYYY')",
} }
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return 'from_unixtime({col})' return "from_unixtime({col})"
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "'{}'".format(dttm.strftime('%Y-%m-%d')) return "'{}'".format(dttm.strftime("%Y-%m-%d"))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod @classmethod
def get_schema_names(cls, inspector): def get_schema_names(cls, inspector):
schemas = [row[0] for row in inspector.engine.execute('SHOW SCHEMAS') schemas = [
if not row[0].startswith('_')] row[0]
for row in inspector.engine.execute("SHOW SCHEMAS")
if not row[0].startswith("_")
]
return schemas return schemas


@ -21,28 +21,27 @@ from superset.db_engine_specs.base import BaseEngineSpec
class KylinEngineSpec(BaseEngineSpec): class KylinEngineSpec(BaseEngineSpec):
"""Dialect for Apache Kylin""" """Dialect for Apache Kylin"""
engine = 'kylin' engine = "kylin"
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO SECOND) AS TIMESTAMP)', "PT1S": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO SECOND) AS TIMESTAMP)",
'PT1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MINUTE) AS TIMESTAMP)', "PT1M": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MINUTE) AS TIMESTAMP)",
'PT1H': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO HOUR) AS TIMESTAMP)', "PT1H": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO HOUR) AS TIMESTAMP)",
'P1D': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO DAY) AS DATE)', "P1D": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO DAY) AS DATE)",
'P1W': 'CAST(TIMESTAMPADD(WEEK, WEEK(CAST({col} AS DATE)) - 1, \ "P1W": "CAST(TIMESTAMPADD(WEEK, WEEK(CAST({col} AS DATE)) - 1, \
FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)', FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)",
'P1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MONTH) AS DATE)', "P1M": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MONTH) AS DATE)",
'P0.25Y': 'CAST(TIMESTAMPADD(QUARTER, QUARTER(CAST({col} AS DATE)) - 1, \ "P0.25Y": "CAST(TIMESTAMPADD(QUARTER, QUARTER(CAST({col} AS DATE)) - 1, \
FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)', FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)",
'P1Y': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO YEAR) AS DATE)', "P1Y": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO YEAR) AS DATE)",
} }
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10]) return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
if tt == 'TIMESTAMP': if tt == "TIMESTAMP":
return "CAST('{}' AS TIMESTAMP)".format( return "CAST('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))


@ -23,25 +23,25 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class MssqlEngineSpec(BaseEngineSpec): class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql' engine = "mssql"
epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')" epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128 max_column_name_length = 128
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': "DATEADD(second, DATEDIFF(second, '2000-01-01', {col}), '2000-01-01')", "PT1S": "DATEADD(second, DATEDIFF(second, '2000-01-01', {col}), '2000-01-01')",
'PT1M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}), 0)', "PT1M": "DATEADD(minute, DATEDIFF(minute, 0, {col}), 0)",
'PT5M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 5 * 5, 0)', "PT5M": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 5 * 5, 0)",
'PT10M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 10 * 10, 0)', "PT10M": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 10 * 10, 0)",
'PT15M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 15 * 15, 0)', "PT15M": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 15 * 15, 0)",
'PT0.5H': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 30 * 30, 0)', "PT0.5H": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 30 * 30, 0)",
'PT1H': 'DATEADD(hour, DATEDIFF(hour, 0, {col}), 0)', "PT1H": "DATEADD(hour, DATEDIFF(hour, 0, {col}), 0)",
'P1D': 'DATEADD(day, DATEDIFF(day, 0, {col}), 0)', "P1D": "DATEADD(day, DATEDIFF(day, 0, {col}), 0)",
'P1W': 'DATEADD(week, DATEDIFF(week, 0, {col}), 0)', "P1W": "DATEADD(week, DATEDIFF(week, 0, {col}), 0)",
'P1M': 'DATEADD(month, DATEDIFF(month, 0, {col}), 0)', "P1M": "DATEADD(month, DATEDIFF(month, 0, {col}), 0)",
'P0.25Y': 'DATEADD(quarter, DATEDIFF(quarter, 0, {col}), 0)', "P0.25Y": "DATEADD(quarter, DATEDIFF(quarter, 0, {col}), 0)",
'P1Y': 'DATEADD(year, DATEDIFF(year, 0, {col}), 0)', "P1Y": "DATEADD(year, DATEDIFF(year, 0, {col}), 0)",
} }
@classmethod @classmethod
@ -51,13 +51,13 @@ class MssqlEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def fetch_data(cls, cursor, limit): def fetch_data(cls, cursor, limit):
data = super(MssqlEngineSpec, cls).fetch_data(cursor, limit) data = super(MssqlEngineSpec, cls).fetch_data(cursor, limit)
if data and type(data[0]).__name__ == 'Row': if data and type(data[0]).__name__ == "Row":
data = [[elem for elem in r] for r in data] data = [[elem for elem in r] for r in data]
return data return data
column_types = [ column_types = [
(String(), re.compile(r'^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)', re.IGNORECASE)), (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
(UnicodeText(), re.compile(r'^N((VAR){0,1}CHAR|TEXT)', re.IGNORECASE)), (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
] ]
@classmethod @classmethod
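The column_types pairs above map raw MSSQL type strings onto SQLAlchemy types via regex. A short sketch of how such a lookup could be applied; to_sqla_type is an illustrative helper, not part of the spec:

import re
from sqlalchemy.types import String, UnicodeText

column_types = [
    (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
    (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
]

def to_sqla_type(raw_type):
    # Return the first SQLAlchemy type whose pattern matches the raw type string
    for sqla_type, pattern in column_types:
        if pattern.match(raw_type):
            return sqla_type
    return None

print(type(to_sqla_type("NVARCHAR(10)")).__name__)  # UnicodeText
print(type(to_sqla_type("varchar(10)")).__name__)   # String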

View File

@ -22,45 +22,42 @@ from superset.db_engine_specs.base import BaseEngineSpec
class MySQLEngineSpec(BaseEngineSpec): class MySQLEngineSpec(BaseEngineSpec):
engine = 'mysql' engine = "mysql"
max_column_name_length = 64 max_column_name_length = 64
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': 'DATE_ADD(DATE({col}), ' "PT1S": "DATE_ADD(DATE({col}), "
'INTERVAL (HOUR({col})*60*60 + MINUTE({col})*60' "INTERVAL (HOUR({col})*60*60 + MINUTE({col})*60"
' + SECOND({col})) SECOND)', " + SECOND({col})) SECOND)",
'PT1M': 'DATE_ADD(DATE({col}), ' "PT1M": "DATE_ADD(DATE({col}), "
'INTERVAL (HOUR({col})*60 + MINUTE({col})) MINUTE)', "INTERVAL (HOUR({col})*60 + MINUTE({col})) MINUTE)",
'PT1H': 'DATE_ADD(DATE({col}), ' "PT1H": "DATE_ADD(DATE({col}), " "INTERVAL HOUR({col}) HOUR)",
'INTERVAL HOUR({col}) HOUR)', "P1D": "DATE({col})",
'P1D': 'DATE({col})', "P1W": "DATE(DATE_SUB({col}, " "INTERVAL DAYOFWEEK({col}) - 1 DAY))",
'P1W': 'DATE(DATE_SUB({col}, ' "P1M": "DATE(DATE_SUB({col}, " "INTERVAL DAYOFMONTH({col}) - 1 DAY))",
'INTERVAL DAYOFWEEK({col}) - 1 DAY))', "P0.25Y": "MAKEDATE(YEAR({col}), 1) "
'P1M': 'DATE(DATE_SUB({col}, ' "+ INTERVAL QUARTER({col}) QUARTER - INTERVAL 1 QUARTER",
'INTERVAL DAYOFMONTH({col}) - 1 DAY))', "P1Y": "DATE(DATE_SUB({col}, " "INTERVAL DAYOFYEAR({col}) - 1 DAY))",
'P0.25Y': 'MAKEDATE(YEAR({col}), 1) ' "1969-12-29T00:00:00Z/P1W": "DATE(DATE_SUB({col}, "
'+ INTERVAL QUARTER({col}) QUARTER - INTERVAL 1 QUARTER', "INTERVAL DAYOFWEEK(DATE_SUB({col}, "
'P1Y': 'DATE(DATE_SUB({col}, ' "INTERVAL 1 DAY)) - 1 DAY))",
'INTERVAL DAYOFYEAR({col}) - 1 DAY))',
'1969-12-29T00:00:00Z/P1W': 'DATE(DATE_SUB({col}, '
'INTERVAL DAYOFWEEK(DATE_SUB({col}, '
'INTERVAL 1 DAY)) - 1 DAY))',
} }
type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
if target_type.upper() in ('DATETIME', 'DATE'): if target_type.upper() in ("DATETIME", "DATE"):
return "STR_TO_DATE('{}', '%Y-%m-%d %H:%i:%s')".format( return "STR_TO_DATE('{}', '%Y-%m-%d %H:%i:%s')".format(
dttm.strftime('%Y-%m-%d %H:%M:%S')) dttm.strftime("%Y-%m-%d %H:%M:%S")
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) )
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod @classmethod
def adjust_database_uri(cls, uri, selected_schema=None): def adjust_database_uri(cls, uri, selected_schema=None):
if selected_schema: if selected_schema:
uri.database = parse.quote(selected_schema, safe='') uri.database = parse.quote(selected_schema, safe="")
return uri return uri
@classmethod @classmethod
@ -68,11 +65,10 @@ class MySQLEngineSpec(BaseEngineSpec):
if not cls.type_code_map: if not cls.type_code_map:
# only import and store if needed at least once # only import and store if needed at least once
import MySQLdb # pylint: disable=import-error import MySQLdb # pylint: disable=import-error
ft = MySQLdb.constants.FIELD_TYPE ft = MySQLdb.constants.FIELD_TYPE
cls.type_code_map = { cls.type_code_map = {
getattr(ft, k): k getattr(ft, k): k for k in dir(ft) if not k.startswith("_")
for k in dir(ft)
if not k.startswith('_')
} }
datatype = type_code datatype = type_code
if isinstance(type_code, int): if isinstance(type_code, int):
@ -82,7 +78,7 @@ class MySQLEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return 'from_unixtime({col})' return "from_unixtime({col})"
@classmethod @classmethod
def extract_error_message(cls, e): def extract_error_message(cls, e):
@ -101,7 +97,7 @@ class MySQLEngineSpec(BaseEngineSpec):
# MySQL dialect started returning long overflowing datatype # MySQL dialect started returning long overflowing datatype
# as in 'VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI' # as in 'VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI'
# and we don't need the verbose collation type # and we don't need the verbose collation type
str_cutoff = ' COLLATE ' str_cutoff = " COLLATE "
if str_cutoff in datatype: if str_cutoff in datatype:
datatype = datatype.split(str_cutoff)[0] datatype = datatype.split(str_cutoff)[0]
return datatype return datatype

View File

@ -20,25 +20,25 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class OracleEngineSpec(PostgresBaseEngineSpec): class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle' engine = "oracle"
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True force_column_alias_quotes = True
max_column_name_length = 30 max_column_name_length = 30
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': 'CAST({col} as DATE)', "PT1S": "CAST({col} as DATE)",
'PT1M': "TRUNC(CAST({col} as DATE), 'MI')", "PT1M": "TRUNC(CAST({col} as DATE), 'MI')",
'PT1H': "TRUNC(CAST({col} as DATE), 'HH')", "PT1H": "TRUNC(CAST({col} as DATE), 'HH')",
'P1D': "TRUNC(CAST({col} as DATE), 'DDD')", "P1D": "TRUNC(CAST({col} as DATE), 'DDD')",
'P1W': "TRUNC(CAST({col} as DATE), 'WW')", "P1W": "TRUNC(CAST({col} as DATE), 'WW')",
'P1M': "TRUNC(CAST({col} as DATE), 'MONTH')", "P1M": "TRUNC(CAST({col} as DATE), 'MONTH')",
'P0.25Y': "TRUNC(CAST({col} as DATE), 'Q')", "P0.25Y": "TRUNC(CAST({col} as DATE), 'Q')",
'P1Y': "TRUNC(CAST({col} as DATE), 'YEAR')", "P1Y": "TRUNC(CAST({col} as DATE), 'YEAR')",
} }
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
return ( return ("""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""").format(
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""" dttm.isoformat()
).format(dttm.isoformat()) )
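For illustration, the Oracle conversion above renders a Python datetime (an arbitrary example value) as a TO_TIMESTAMP literal:

from datetime import datetime

dttm = datetime(2019, 6, 25, 13, 34, 48, 123456)
print("""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""".format(dttm.isoformat()))
# TO_TIMESTAMP('2019-06-25T13:34:48.123456', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')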

View File

@ -23,37 +23,38 @@ from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
class PinotEngineSpec(BaseEngineSpec): class PinotEngineSpec(BaseEngineSpec):
engine = 'pinot' engine = "pinot"
allows_subquery = False allows_subquery = False
inner_joins = False inner_joins = False
supports_column_aliases = False supports_column_aliases = False
# Pinot does its own conversion below # Pinot does its own conversion below
time_grain_functions: Dict[Optional[str], str] = { time_grain_functions: Dict[Optional[str], str] = {
'PT1S': '1:SECONDS', "PT1S": "1:SECONDS",
'PT1M': '1:MINUTES', "PT1M": "1:MINUTES",
'PT1H': '1:HOURS', "PT1H": "1:HOURS",
'P1D': '1:DAYS', "P1D": "1:DAYS",
'P1W': '1:WEEKS', "P1W": "1:WEEKS",
'P1M': '1:MONTHS', "P1M": "1:MONTHS",
'P0.25Y': '3:MONTHS', "P0.25Y": "3:MONTHS",
'P1Y': '1:YEARS', "P1Y": "1:YEARS",
} }
@classmethod @classmethod
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str], def get_timestamp_expr(
time_grain: Optional[str]) -> TimestampExpression: cls, col: ColumnClause, pdf: Optional[str], time_grain: Optional[str]
is_epoch = pdf in ('epoch_s', 'epoch_ms') ) -> TimestampExpression:
is_epoch = pdf in ("epoch_s", "epoch_ms")
if not is_epoch: if not is_epoch:
raise NotImplementedError('Pinot currently only supports epochs') raise NotImplementedError("Pinot currently only supports epochs")
# The DATETIMECONVERT pinot udf is documented at # The DATETIMECONVERT pinot udf is documented at
# Per https://github.com/apache/incubator-pinot/wiki/dateTimeConvert-UDF # Per https://github.com/apache/incubator-pinot/wiki/dateTimeConvert-UDF
# We are not really converting any time units, just bucketing them. # We are not really converting any time units, just bucketing them.
seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS' seconds_or_ms = "MILLISECONDS" if pdf == "epoch_ms" else "SECONDS"
tf = f'1:{seconds_or_ms}:EPOCH' tf = f"1:{seconds_or_ms}:EPOCH"
granularity = cls.time_grain_functions.get(time_grain) granularity = cls.time_grain_functions.get(time_grain)
if not granularity: if not granularity:
raise NotImplementedError('No pinot grain spec for ' + str(time_grain)) raise NotImplementedError("No pinot grain spec for " + str(time_grain))
# In pinot the output is a string since there is no timestamp column like pg # In pinot the output is a string since there is no timestamp column like pg
time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")' time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
return TimestampExpression(time_expr, col) return TimestampExpression(time_expr, col)
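As a sketch of the expression get_timestamp_expr builds, assuming an epoch-seconds column bucketed to the daily grain:

pdf, granularity = "epoch_s", "1:DAYS"  # granularity comes from time_grain_functions["P1D"]
seconds_or_ms = "MILLISECONDS" if pdf == "epoch_ms" else "SECONDS"
tf = f"1:{seconds_or_ms}:EPOCH"
print(f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")')
# DATETIMECONVERT({col}, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:DAYS")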

View File

@ -21,18 +21,18 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class PostgresBaseEngineSpec(BaseEngineSpec): class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """ """ Abstract class for Postgres 'like' databases """
engine = '' engine = ""
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': "DATE_TRUNC('second', {col})", "PT1S": "DATE_TRUNC('second', {col})",
'PT1M': "DATE_TRUNC('minute', {col})", "PT1M": "DATE_TRUNC('minute', {col})",
'PT1H': "DATE_TRUNC('hour', {col})", "PT1H": "DATE_TRUNC('hour', {col})",
'P1D': "DATE_TRUNC('day', {col})", "P1D": "DATE_TRUNC('day', {col})",
'P1W': "DATE_TRUNC('week', {col})", "P1W": "DATE_TRUNC('week', {col})",
'P1M': "DATE_TRUNC('month', {col})", "P1M": "DATE_TRUNC('month', {col})",
'P0.25Y': "DATE_TRUNC('quarter', {col})", "P0.25Y": "DATE_TRUNC('quarter', {col})",
'P1Y': "DATE_TRUNC('year', {col})", "P1Y": "DATE_TRUNC('year', {col})",
} }
@classmethod @classmethod
@ -49,11 +49,11 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
class PostgresEngineSpec(PostgresBaseEngineSpec): class PostgresEngineSpec(PostgresBaseEngineSpec):
engine = 'postgresql' engine = "postgresql"
max_column_name_length = 63 max_column_name_length = 63
try_remove_schema_from_table_name = False try_remove_schema_from_table_name = False
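The time grain templates above are plain format strings keyed by ISO 8601 durations, with {col} standing in for the time column expression. A trivial sketch using a hypothetical created_at column:

time_grain_functions = {"P1M": "DATE_TRUNC('month', {col})"}
print(time_grain_functions["P1M"].format(col="created_at"))  # DATE_TRUNC('month', created_at)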

View File

@ -38,24 +38,22 @@ QueryStatus = utils.QueryStatus
class PrestoEngineSpec(BaseEngineSpec): class PrestoEngineSpec(BaseEngineSpec):
engine = 'presto' engine = "presto"
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))", "PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))",
'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))", "PT1M": "date_trunc('minute', CAST({col} AS TIMESTAMP))",
'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))", "PT1H": "date_trunc('hour', CAST({col} AS TIMESTAMP))",
'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))", "P1D": "date_trunc('day', CAST({col} AS TIMESTAMP))",
'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))", "P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))",
'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))", "P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))",
'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))", "P0.25Y": "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))", "P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))",
'P1W/1970-01-03T00:00:00Z': "P1W/1970-01-03T00:00:00Z": "date_add('day', 5, date_trunc('week', "
"date_add('day', 5, date_trunc('week', date_add('day', 1, " "date_add('day', 1, CAST({col} AS TIMESTAMP))))",
'CAST({col} AS TIMESTAMP))))', "1969-12-28T00:00:00Z/P1W": "date_add('day', -1, date_trunc('week', "
'1969-12-28T00:00:00Z/P1W': "date_add('day', 1, CAST({col} AS TIMESTAMP))))",
"date_add('day', -1, date_trunc('week', "
"date_add('day', 1, CAST({col} AS TIMESTAMP))))",
} }
@classmethod @classmethod
@ -76,10 +74,7 @@ class PrestoEngineSpec(BaseEngineSpec):
:param data_type: column data type :param data_type: column data type
:return: column info object :return: column info object
""" """
return { return {"name": name, "type": f"{data_type}"}
'name': name,
'type': f'{data_type}',
}
@classmethod @classmethod
def _get_full_name(cls, names: List[Tuple[str, str]]) -> str: def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
@ -88,7 +83,7 @@ class PrestoEngineSpec(BaseEngineSpec):
:param names: list of all individual column names :param names: list of all individual column names
:return: full column name :return: full column name
""" """
return '.'.join(column[0] for column in names if column[0]) return ".".join(column[0] for column in names if column[0])
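A quick sketch of the name joining above, using a hypothetical stack of (name, type) pairs built while parsing a row type:

names = [("address", "row"), ("city", "varchar")]
print(".".join(column[0] for column in names if column[0]))  # address.city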
@classmethod @classmethod
def _has_nested_data_types(cls, component_type: str) -> bool: def _has_nested_data_types(cls, component_type: str) -> bool:
@ -98,10 +93,12 @@ class PrestoEngineSpec(BaseEngineSpec):
:param component_type: data type :param component_type: data type
:return: boolean :return: boolean
""" """
comma_regex = r',(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)' comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
white_space_regex = r'\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)' white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
return re.search(comma_regex, component_type) is not None \ return (
re.search(comma_regex, component_type) is not None
or re.search(white_space_regex, component_type) is not None or re.search(white_space_regex, component_type) is not None
)
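Restating the nested-type check above as a standalone sketch (has_nested is an illustrative name, not part of the spec):

import re

comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"

def has_nested(component_type):
    # A type has nested fields if it contains an unquoted comma or unquoted whitespace
    return (
        re.search(comma_regex, component_type) is not None
        or re.search(white_space_regex, component_type) is not None
    )

print(has_nested("varchar"))           # False: plain scalar type
print(has_nested("a varchar, b int"))  # True: comma-separated nested fields
print(has_nested('"my,col" varchar'))  # True: the quoted comma is ignored, the space is not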
@classmethod @classmethod
def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]: def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
@ -114,38 +111,38 @@ class PrestoEngineSpec(BaseEngineSpec):
:return: list of strings after breaking it by the delimiter :return: list of strings after breaking it by the delimiter
""" """
return re.split( return re.split(
r'{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'.format(delimiter), data_type) r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
)
@classmethod @classmethod
def _parse_structural_column(cls, def _parse_structural_column(
parent_column_name: str, cls, parent_column_name: str, parent_data_type: str, result: List[dict]
parent_data_type: str, ) -> None:
result: List[dict]) -> None:
""" """
Parse a row or array column Parse a row or array column
:param result: list tracking the results :param result: list tracking the results
""" """
formatted_parent_column_name = parent_column_name formatted_parent_column_name = parent_column_name
# Quote the column name if there is a space # Quote the column name if there is a space
if ' ' in parent_column_name: if " " in parent_column_name:
formatted_parent_column_name = f'"{parent_column_name}"' formatted_parent_column_name = f'"{parent_column_name}"'
full_data_type = f'{formatted_parent_column_name} {parent_data_type}' full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
original_result_len = len(result) original_result_len = len(result)
# split on open parenthesis ( to get the structural # split on open parenthesis ( to get the structural
# data type and its component types # data type and its component types
data_types = cls._split_data_type(full_data_type, r'\(') data_types = cls._split_data_type(full_data_type, r"\(")
stack: List[Tuple[str, str]] = [] stack: List[Tuple[str, str]] = []
for data_type in data_types: for data_type in data_types:
# split on closed parenthesis ) to track which component # split on closed parenthesis ) to track which component
# types belong to what structural data type # types belong to what structural data type
inner_types = cls._split_data_type(data_type, r'\)') inner_types = cls._split_data_type(data_type, r"\)")
for inner_type in inner_types: for inner_type in inner_types:
# We have finished parsing multiple structural data types # We have finished parsing multiple structural data types
if not inner_type and len(stack) > 0: if not inner_type and len(stack) > 0:
stack.pop() stack.pop()
elif cls._has_nested_data_types(inner_type): elif cls._has_nested_data_types(inner_type):
# split on comma , to get individual data types # split on comma , to get individual data types
single_fields = cls._split_data_type(inner_type, ',') single_fields = cls._split_data_type(inner_type, ",")
for single_field in single_fields: for single_field in single_fields:
single_field = single_field.strip() single_field = single_field.strip()
# If component type starts with a comma, the first single field # If component type starts with a comma, the first single field
@ -153,30 +150,37 @@ class PrestoEngineSpec(BaseEngineSpec):
if not single_field: if not single_field:
continue continue
# split on whitespace to get field name and data type # split on whitespace to get field name and data type
field_info = cls._split_data_type(single_field, r'\s') field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within # check if there is a structural data type within
# overall structural data type # overall structural data type
if field_info[1] == 'array' or field_info[1] == 'row': if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1])) stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack) full_parent_path = cls._get_full_name(stack)
result.append(cls._create_column_info( result.append(
full_parent_path, cls._create_column_info(
presto_type_map[field_info[1]]())) full_parent_path, presto_type_map[field_info[1]]()
)
)
else: # otherwise this field is a basic data type else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack) full_parent_path = cls._get_full_name(stack)
column_name = '{}.{}'.format(full_parent_path, field_info[0]) column_name = "{}.{}".format(
result.append(cls._create_column_info( full_parent_path, field_info[0]
column_name, presto_type_map[field_info[1]]())) )
result.append(
cls._create_column_info(
column_name, presto_type_map[field_info[1]]()
)
)
# If the component type ends with a structural data type, do not pop # If the component type ends with a structural data type, do not pop
# the stack. We have run across a structural data type within the # the stack. We have run across a structural data type within the
# overall structural data type. Otherwise, we have completely parsed # overall structural data type. Otherwise, we have completely parsed
# through the entire structural data type and can move on. # through the entire structural data type and can move on.
if not (inner_type.endswith('array') or inner_type.endswith('row')): if not (inner_type.endswith("array") or inner_type.endswith("row")):
stack.pop() stack.pop()
# We have an array of row objects (i.e. array(row(...))) # We have an array of row objects (i.e. array(row(...)))
elif 'array' == inner_type or 'row' == inner_type: elif "array" == inner_type or "row" == inner_type:
# Push a dummy object to represent the structural data type # Push a dummy object to represent the structural data type
stack.append(('', inner_type)) stack.append(("", inner_type))
# We have an array of a basic data type (i.e. array(varchar)). # We have an array of a basic data type (i.e. array(varchar)).
elif len(stack) > 0: elif len(stack) > 0:
# Because it is an array of a basic data type. We have finished # Because it is an array of a basic data type. We have finished
@ -185,12 +189,14 @@ class PrestoEngineSpec(BaseEngineSpec):
# Unquote the column name if necessary # Unquote the column name if necessary
if formatted_parent_column_name != parent_column_name: if formatted_parent_column_name != parent_column_name:
for index in range(original_result_len, len(result)): for index in range(original_result_len, len(result)):
result[index]['name'] = result[index]['name'].replace( result[index]["name"] = result[index]["name"].replace(
formatted_parent_column_name, parent_column_name) formatted_parent_column_name, parent_column_name
)
@classmethod @classmethod
def _show_columns( def _show_columns(
cls, inspector: Inspector, table_name: str, schema: str) -> List[RowProxy]: cls, inspector: Inspector, table_name: str, schema: str
) -> List[RowProxy]:
""" """
Show presto column names Show presto column names
:param inspector: object that performs database schema inspection :param inspector: object that performs database schema inspection
@ -201,13 +207,14 @@ class PrestoEngineSpec(BaseEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name) full_table = quote(table_name)
if schema: if schema:
full_table = '{}.{}'.format(quote(schema), full_table) full_table = "{}.{}".format(quote(schema), full_table)
columns = inspector.bind.execute('SHOW COLUMNS FROM {}'.format(full_table)) columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
return columns return columns
@classmethod @classmethod
def get_columns( def get_columns(
cls, inspector: Inspector, table_name: str, schema: str) -> List[dict]: cls, inspector: Inspector, table_name: str, schema: str
) -> List[dict]:
""" """
Get columns from a Presto data source. This includes handling row and Get columns from a Presto data source. This includes handling row and
array data types array data types
@ -222,22 +229,26 @@ class PrestoEngineSpec(BaseEngineSpec):
for column in columns: for column in columns:
try: try:
# parse column if it is a row or array # parse column if it is a row or array
if 'array' in column.Type or 'row' in column.Type: if "array" in column.Type or "row" in column.Type:
structural_column_index = len(result) structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result) cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]['nullable'] = getattr( result[structural_column_index]["nullable"] = getattr(
column, 'Null', True) column, "Null", True
result[structural_column_index]['default'] = None )
result[structural_column_index]["default"] = None
continue continue
else: # otherwise column is a basic data type else: # otherwise column is a basic data type
column_type = presto_type_map[column.Type]() column_type = presto_type_map[column.Type]()
except KeyError: except KeyError:
logging.info('Did not recognize type {} of column {}'.format( logging.info(
column.Type, column.Column)) "Did not recognize type {} of column {}".format(
column.Type, column.Column
)
)
column_type = types.NullType column_type = types.NullType
column_info = cls._create_column_info(column.Column, column_type) column_info = cls._create_column_info(column.Column, column_type)
column_info['nullable'] = getattr(column, 'Null', True) column_info["nullable"] = getattr(column, "Null", True)
column_info['default'] = None column_info["default"] = None
result.append(column_info) result.append(column_info)
return result return result
@ -258,9 +269,9 @@ class PrestoEngineSpec(BaseEngineSpec):
:return: column clauses :return: column clauses
""" """
column_clauses = [] column_clauses = []
# Column names are separated by periods. This regex will find periods in a string # Column names are separated by periods. This regex will find periods in a
# if they are not enclosed in quotes because if a period is enclosed in quotes, # string if they are not enclosed in quotes because if a period is enclosed in
# then that period is part of a column name. # quotes, then that period is part of a column name.
dot_pattern = r"""\. # split on period dot_pattern = r"""\. # split on period
(?= # look ahead (?= # look ahead
(?: # create non-capture group (?: # create non-capture group
@ -269,26 +280,28 @@ class PrestoEngineSpec(BaseEngineSpec):
dot_regex = re.compile(dot_pattern, re.VERBOSE) dot_regex = re.compile(dot_pattern, re.VERBOSE)
for col in cols: for col in cols:
# get individual column names # get individual column names
col_names = re.split(dot_regex, col['name']) col_names = re.split(dot_regex, col["name"])
# quote each column name if it is not already quoted # quote each column name if it is not already quoted
for index, col_name in enumerate(col_names): for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name): if not cls._is_column_name_quoted(col_name):
col_names[index] = '"{}"'.format(col_name) col_names[index] = '"{}"'.format(col_name)
quoted_col_name = '.'.join( quoted_col_name = ".".join(
col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"' col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
for col_name in col_names) for col_name in col_names
)
# create column clause in the format "name"."name" AS "name.name" # create column clause in the format "name"."name" AS "name.name"
column_clause = literal_column(quoted_col_name).label(col['name']) column_clause = literal_column(quoted_col_name).label(col["name"])
column_clauses.append(column_clause) column_clauses.append(column_clause)
return column_clauses return column_clauses
@classmethod @classmethod
def _filter_out_array_nested_cols( def _filter_out_array_nested_cols(
cls, cols: List[dict]) -> Tuple[List[dict], List[dict]]: cls, cols: List[dict]
) -> Tuple[List[dict], List[dict]]:
""" """
Filter out columns that correspond to array content. We know which columns to Filter out columns that correspond to array content. We know which columns to
skip because cols is a list provided to us in a specific order where a structural skip because cols is a list provided to us in a specific order where a
column is positioned right before its content. structural column is positioned right before its content.
Example: Column Name: ColA, Column Data Type: array(row(nest_obj int)) Example: Column Name: ColA, Column Data Type: array(row(nest_obj int))
cols = [ ..., ColA, ColA.nest_obj, ... ] cols = [ ..., ColA, ColA.nest_obj, ... ]
@ -296,20 +309,21 @@ class PrestoEngineSpec(BaseEngineSpec):
When we run across an array, check if subsequent column names start with the When we run across an array, check if subsequent column names start with the
array name and skip them. array name and skip them.
:param cols: columns :param cols: columns
:return: filtered list of columns and list of array columns and its nested fields :return: filtered list of columns and list of array columns and its nested
fields
""" """
filtered_cols = [] filtered_cols = []
array_cols = [] array_cols = []
curr_array_col_name = None curr_array_col_name = None
for col in cols: for col in cols:
# col corresponds to an array's content and should be skipped # col corresponds to an array's content and should be skipped
if curr_array_col_name and col['name'].startswith(curr_array_col_name): if curr_array_col_name and col["name"].startswith(curr_array_col_name):
array_cols.append(col) array_cols.append(col)
continue continue
# col is an array so we need to check if subsequent # col is an array so we need to check if subsequent
# columns correspond to the array's contents # columns correspond to the array's contents
elif str(col['type']) == 'ARRAY': elif str(col["type"]) == "ARRAY":
curr_array_col_name = col['name'] curr_array_col_name = col["name"]
array_cols.append(col) array_cols.append(col)
filtered_cols.append(col) filtered_cols.append(col)
else: else:
@ -318,9 +332,18 @@ class PrestoEngineSpec(BaseEngineSpec):
return filtered_cols, array_cols return filtered_cols, array_cols
@classmethod @classmethod
def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None, def select_star(
limit: int = 100, show_cols: bool = False, indent: bool = True, cls,
latest_partition: bool = True, cols: List[dict] = []) -> str: my_db,
table_name: str,
engine: Engine,
schema: str = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: List[dict] = [],
) -> str:
""" """
Include selecting properties of row objects. We cannot easily break arrays into Include selecting properties of row objects. We cannot easily break arrays into
rows, so render the whole array in its own row and skip columns that correspond rows, so render the whole array in its own row and skip columns that correspond
@ -328,59 +351,71 @@ class PrestoEngineSpec(BaseEngineSpec):
""" """
presto_cols = cols presto_cols = cols
if show_cols: if show_cols:
dot_regex = r'\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)' dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
presto_cols = [ presto_cols = [
col for col in presto_cols if not re.search(dot_regex, col['name'])] col for col in presto_cols if not re.search(dot_regex, col["name"])
]
return super(PrestoEngineSpec, cls).select_star( return super(PrestoEngineSpec, cls).select_star(
my_db, table_name, engine, schema, limit, my_db,
show_cols, indent, latest_partition, presto_cols, table_name,
engine,
schema,
limit,
show_cols,
indent,
latest_partition,
presto_cols,
) )
@classmethod @classmethod
def adjust_database_uri(cls, uri, selected_schema=None): def adjust_database_uri(cls, uri, selected_schema=None):
database = uri.database database = uri.database
if selected_schema and database: if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe='') selected_schema = parse.quote(selected_schema, safe="")
if '/' in database: if "/" in database:
database = database.split('/')[0] + '/' + selected_schema database = database.split("/")[0] + "/" + selected_schema
else: else:
database += '/' + selected_schema database += "/" + selected_schema
uri.database = database uri.database = database
return uri return uri
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()
if tt == 'DATE': if tt == "DATE":
return "from_iso8601_date('{}')".format(dttm.isoformat()[:10]) return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
if tt == 'TIMESTAMP': if tt == "TIMESTAMP":
return "from_iso8601_timestamp('{}')".format(dttm.isoformat()) return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return 'from_unixtime({col})' return "from_unixtime({col})"
@classmethod @classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \ def get_all_datasource_names(
-> List[utils.DatasourceName]: cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
datasource_df = db.get_df( datasource_df = db.get_df(
'SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S ' "SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S "
"ORDER BY concat(table_schema, '.', table_name)".format( "ORDER BY concat(table_schema, '.', table_name)".format(
datasource_type.upper(), datasource_type.upper()
), ),
None) None,
)
datasource_names: List[utils.DatasourceName] = [] datasource_names: List[utils.DatasourceName] = []
for unused, row in datasource_df.iterrows(): for unused, row in datasource_df.iterrows():
datasource_names.append(utils.DatasourceName( datasource_names.append(
schema=row['table_schema'], table=row['table_name'])) utils.DatasourceName(
schema=row["table_schema"], table=row["table_name"]
)
)
return datasource_names return datasource_names
@classmethod @classmethod
def _build_column_hierarchy(cls, def _build_column_hierarchy(
columns: List[dict], cls, columns: List[dict], parent_column_types: List[str], column_hierarchy: dict
parent_column_types: List[str], ) -> None:
column_hierarchy: dict) -> None:
""" """
Build a graph where the root node represents a column whose data type is in Build a graph where the root node represents a column whose data type is in
parent_column_types. A node's children represent that column's nested fields parent_column_types. A node's children represent that column's nested fields
@ -392,28 +427,30 @@ class PrestoEngineSpec(BaseEngineSpec):
if len(columns) == 0: if len(columns) == 0:
return return
root = columns.pop(0) root = columns.pop(0)
root_info = {'type': root['type'], 'children': []} root_info = {"type": root["type"], "children": []}
column_hierarchy[root['name']] = root_info column_hierarchy[root["name"]] = root_info
while columns: while columns:
column = columns[0] column = columns[0]
# If the column name does not start with the root's name, # If the column name does not start with the root's name,
# then this column is not a nested field # then this column is not a nested field
if not column['name'].startswith(f"{root['name']}."): if not column["name"].startswith(f"{root['name']}."):
break break
# If the column's data type is one of the parent types, # If the column's data type is one of the parent types,
# then this column may have nested fields # then this column may have nested fields
if str(column['type']) in parent_column_types: if str(column["type"]) in parent_column_types:
cls._build_column_hierarchy(columns, parent_column_types, cls._build_column_hierarchy(
column_hierarchy) columns, parent_column_types, column_hierarchy
root_info['children'].append(column['name']) )
root_info["children"].append(column["name"])
continue continue
else: # The column is a nested field else: # The column is a nested field
root_info['children'].append(column['name']) root_info["children"].append(column["name"])
columns.pop(0) columns.pop(0)
@classmethod @classmethod
def _create_row_and_array_hierarchy( def _create_row_and_array_hierarchy(
cls, selected_columns: List[dict]) -> Tuple[dict, dict, List[dict]]: cls, selected_columns: List[dict]
) -> Tuple[dict, dict, List[dict]]:
""" """
Build graphs where the root node represents a row or array and its children Build graphs where the root node represents a row or array and its children
are that column's nested fields are that column's nested fields
@ -425,29 +462,30 @@ class PrestoEngineSpec(BaseEngineSpec):
array_column_hierarchy: OrderedDict = OrderedDict() array_column_hierarchy: OrderedDict = OrderedDict()
expanded_columns: List[dict] = [] expanded_columns: List[dict] = []
for column in selected_columns: for column in selected_columns:
if column['type'].startswith('ROW'): if column["type"].startswith("ROW"):
parsed_row_columns: List[dict] = [] parsed_row_columns: List[dict] = []
cls._parse_structural_column(column['name'], cls._parse_structural_column(
column['type'].lower(), column["name"], column["type"].lower(), parsed_row_columns
parsed_row_columns) )
expanded_columns = expanded_columns + parsed_row_columns[1:] expanded_columns = expanded_columns + parsed_row_columns[1:]
filtered_row_columns, array_columns = cls._filter_out_array_nested_cols( filtered_row_columns, array_columns = cls._filter_out_array_nested_cols(
parsed_row_columns) parsed_row_columns
cls._build_column_hierarchy(filtered_row_columns, )
['ROW'], cls._build_column_hierarchy(
row_column_hierarchy) filtered_row_columns, ["ROW"], row_column_hierarchy
cls._build_column_hierarchy(array_columns, )
['ROW', 'ARRAY'], cls._build_column_hierarchy(
array_column_hierarchy) array_columns, ["ROW", "ARRAY"], array_column_hierarchy
elif column['type'].startswith('ARRAY'): )
elif column["type"].startswith("ARRAY"):
parsed_array_columns: List[dict] = [] parsed_array_columns: List[dict] = []
cls._parse_structural_column(column['name'], cls._parse_structural_column(
column['type'].lower(), column["name"], column["type"].lower(), parsed_array_columns
parsed_array_columns) )
expanded_columns = expanded_columns + parsed_array_columns[1:] expanded_columns = expanded_columns + parsed_array_columns[1:]
cls._build_column_hierarchy(parsed_array_columns, cls._build_column_hierarchy(
['ROW', 'ARRAY'], parsed_array_columns, ["ROW", "ARRAY"], array_column_hierarchy
array_column_hierarchy) )
return row_column_hierarchy, array_column_hierarchy, expanded_columns return row_column_hierarchy, array_column_hierarchy, expanded_columns
@classmethod @classmethod
@ -457,7 +495,7 @@ class PrestoEngineSpec(BaseEngineSpec):
:param columns: list of columns :param columns: list of columns
:return: dictionary representing an empty row of data :return: dictionary representing an empty row of data
""" """
return {column['name']: '' for column in columns} return {column["name"]: "" for column in columns}
@classmethod @classmethod
def _expand_row_data(cls, datum: dict, column: str, column_hierarchy: dict) -> None: def _expand_row_data(cls, datum: dict, column: str, column_hierarchy: dict) -> None:
@ -470,22 +508,23 @@ class PrestoEngineSpec(BaseEngineSpec):
""" """
if column in datum: if column in datum:
row_data = datum[column] row_data = datum[column]
row_children = column_hierarchy[column]['children'] row_children = column_hierarchy[column]["children"]
if row_data and len(row_data) != len(row_children): if row_data and len(row_data) != len(row_children):
raise Exception('The number of data values and number of nested ' raise Exception(
'fields are not equal') "The number of data values and number of nested "
"fields are not equal"
)
elif row_data: elif row_data:
for index, data_value in enumerate(row_data): for index, data_value in enumerate(row_data):
datum[row_children[index]] = data_value datum[row_children[index]] = data_value
else: else:
for row_child in row_children: for row_child in row_children:
datum[row_child] = '' datum[row_child] = ""
@classmethod @classmethod
def _split_array_columns_by_process_state( def _split_array_columns_by_process_state(
cls, array_columns: List[str], cls, array_columns: List[str], array_column_hierarchy: dict, datum: dict
array_column_hierarchy: dict, ) -> Tuple[List[str], Set[str]]:
datum: dict) -> Tuple[List[str], Set[str]]:
""" """
Take a list of array columns and split them according to whether or not we are Take a list of array columns and split them according to whether or not we are
ready to process them from a data set ready to process them from a data set
@ -501,7 +540,7 @@ class PrestoEngineSpec(BaseEngineSpec):
for array_column in array_columns: for array_column in array_columns:
if array_column in datum: if array_column in datum:
array_columns_to_process.append(array_column) array_columns_to_process.append(array_column)
elif str(array_column_hierarchy[array_column]['type']) == 'ARRAY': elif str(array_column_hierarchy[array_column]["type"]) == "ARRAY":
child_array = array_column child_array = array_column
unprocessed_array_columns.add(child_array) unprocessed_array_columns.add(child_array)
elif child_array and array_column.startswith(child_array): elif child_array and array_column.startswith(child_array):
@ -510,7 +549,8 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def _convert_data_list_to_array_data_dict( def _convert_data_list_to_array_data_dict(
cls, data: List[dict], array_columns_to_process: List[str]) -> dict: cls, data: List[dict], array_columns_to_process: List[str]
) -> dict:
""" """
Pull out array data from rows of data into a dictionary where the key represents Pull out array data from rows of data into a dictionary where the key represents
the index in the data list and the value is the array data values the index in the data list and the value is the array data values
@ -536,10 +576,9 @@ class PrestoEngineSpec(BaseEngineSpec):
return array_data_dict return array_data_dict
@classmethod @classmethod
def _process_array_data(cls, def _process_array_data(
data: List[dict], cls, data: List[dict], all_columns: List[dict], array_column_hierarchy: dict
all_columns: List[dict], ) -> dict:
array_column_hierarchy: dict) -> dict:
""" """
Pull out array data that is ready to be processed into a dictionary. Pull out array data that is ready to be processed into a dictionary.
The key refers to the index in the original data set. The value is The key refers to the index in the original data set. The value is
@ -575,38 +614,39 @@ class PrestoEngineSpec(BaseEngineSpec):
# Determine what columns are ready to be processed. This is necessary for # Determine what columns are ready to be processed. This is necessary for
# array columns that contain rows with nested arrays. We first process # array columns that contain rows with nested arrays. We first process
# the outer arrays before processing inner arrays. # the outer arrays before processing inner arrays.
array_columns_to_process, \ array_columns_to_process, unprocessed_array_columns = cls._split_array_columns_by_process_state(
unprocessed_array_columns = cls._split_array_columns_by_process_state( array_columns, array_column_hierarchy, data[0]
array_columns, array_column_hierarchy, data[0]) )
# Pull out array data that is ready to be processed into a dictionary. # Pull out array data that is ready to be processed into a dictionary.
all_array_data = cls._convert_data_list_to_array_data_dict( all_array_data = cls._convert_data_list_to_array_data_dict(
data, array_columns_to_process) data, array_columns_to_process
)
for original_data_index, expanded_array_data in all_array_data.items(): for original_data_index, expanded_array_data in all_array_data.items():
for array_column in array_columns: for array_column in array_columns:
if array_column in unprocessed_array_columns: if array_column in unprocessed_array_columns:
continue continue
# Expand array values that are rows # Expand array values that are rows
if str(array_column_hierarchy[array_column]['type']) == 'ROW': if str(array_column_hierarchy[array_column]["type"]) == "ROW":
for array_value in expanded_array_data: for array_value in expanded_array_data:
cls._expand_row_data(array_value, cls._expand_row_data(
array_column, array_value, array_column, array_column_hierarchy
array_column_hierarchy) )
continue continue
array_data = expanded_array_data[0][array_column] array_data = expanded_array_data[0][array_column]
array_children = array_column_hierarchy[array_column] array_children = array_column_hierarchy[array_column]
# This is an empty array of primitive data type # This is an empty array of primitive data type
if not array_data and not array_children['children']: if not array_data and not array_children["children"]:
continue continue
# Pull out complex array values into its own row of data # Pull out complex array values into its own row of data
elif array_data and array_children['children']: elif array_data and array_children["children"]:
for array_index, data_value in enumerate(array_data): for array_index, data_value in enumerate(array_data):
if array_index >= len(expanded_array_data): if array_index >= len(expanded_array_data):
empty_data = cls._create_empty_row_of_data(all_columns) empty_data = cls._create_empty_row_of_data(all_columns)
expanded_array_data.append(empty_data) expanded_array_data.append(empty_data)
for index, datum_value in enumerate(data_value): for index, datum_value in enumerate(data_value):
array_child = array_children['children'][index] array_child = array_children["children"][index]
expanded_array_data[array_index][array_child] = datum_value expanded_array_data[array_index][array_child] = datum_value
# Pull out primitive array values into its own row of data # Pull out primitive array values into its own row of data
elif array_data: elif array_data:
@ -617,15 +657,15 @@ class PrestoEngineSpec(BaseEngineSpec):
expanded_array_data[array_index][array_column] = data_value expanded_array_data[array_index][array_column] = data_value
# This is an empty array with nested fields # This is an empty array with nested fields
else: else:
for index, array_child in enumerate(array_children['children']): for index, array_child in enumerate(array_children["children"]):
for array_value in expanded_array_data: for array_value in expanded_array_data:
array_value[array_child] = '' array_value[array_child] = ""
return all_array_data return all_array_data
@classmethod @classmethod
def _consolidate_array_data_into_data(cls, def _consolidate_array_data_into_data(
data: List[dict], cls, data: List[dict], array_data: dict
array_data: dict) -> None: ) -> None:
""" """
Consolidate data given a list representing rows of data and a dictionary Consolidate data given a list representing rows of data and a dictionary
representing expanded array data representing expanded array data
@ -659,14 +699,14 @@ class PrestoEngineSpec(BaseEngineSpec):
while data_index < len(data): while data_index < len(data):
data[data_index].update(array_data[original_data_index][0]) data[data_index].update(array_data[original_data_index][0])
array_data[original_data_index].pop(0) array_data[original_data_index].pop(0)
data[data_index + 1:data_index + 1] = array_data[original_data_index] data[data_index + 1 : data_index + 1] = array_data[original_data_index]
data_index = data_index + len(array_data[original_data_index]) + 1 data_index = data_index + len(array_data[original_data_index]) + 1
original_data_index = original_data_index + 1 original_data_index = original_data_index + 1
@classmethod @classmethod
def _remove_processed_array_columns(cls, def _remove_processed_array_columns(
unprocessed_array_columns: Set[str], cls, unprocessed_array_columns: Set[str], array_column_hierarchy: dict
array_column_hierarchy: dict) -> None: ) -> None:
""" """
Remove keys representing array columns that have already been processed Remove keys representing array columns that have already been processed
:param unprocessed_array_columns: list of unprocessed array columns :param unprocessed_array_columns: list of unprocessed array columns
@ -680,9 +720,9 @@ class PrestoEngineSpec(BaseEngineSpec):
del array_column_hierarchy[array_column] del array_column_hierarchy[array_column]
@classmethod @classmethod
def expand_data(cls, def expand_data(
columns: List[dict], cls, columns: List[dict], data: List[dict]
data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]: ) -> Tuple[List[dict], List[dict], List[dict]]:
""" """
We do not immediately display rows and arrays clearly in the data grid. This We do not immediately display rows and arrays clearly in the data grid. This
method separates out nested fields and data values to help clearly display method separates out nested fields and data values to help clearly display
@ -707,18 +747,18 @@ class PrestoEngineSpec(BaseEngineSpec):
all_columns: List[dict] = [] all_columns: List[dict] = []
# Get the list of all columns (selected fields and their nested fields) # Get the list of all columns (selected fields and their nested fields)
for column in columns: for column in columns:
if column['type'].startswith('ARRAY') or column['type'].startswith('ROW'): if column["type"].startswith("ARRAY") or column["type"].startswith("ROW"):
cls._parse_structural_column(column['name'], cls._parse_structural_column(
column['type'].lower(), column["name"], column["type"].lower(), all_columns
all_columns) )
else: else:
all_columns.append(column) all_columns.append(column)
# Build graphs where the root node is a row or array and its children are that # Build graphs where the root node is a row or array and its children are that
# column's nested fields # column's nested fields
row_column_hierarchy,\ row_column_hierarchy, array_column_hierarchy, expanded_columns = cls._create_row_and_array_hierarchy(
array_column_hierarchy,\ columns
expanded_columns = cls._create_row_and_array_hierarchy(columns) )
# Pull out a row's nested fields and their values into separate columns # Pull out a row's nested fields and their values into separate columns
ordered_row_columns = row_column_hierarchy.keys() ordered_row_columns = row_column_hierarchy.keys()
@ -729,17 +769,18 @@ class PrestoEngineSpec(BaseEngineSpec):
while array_column_hierarchy: while array_column_hierarchy:
array_columns = list(array_column_hierarchy.keys()) array_columns = list(array_column_hierarchy.keys())
# Determine what columns are ready to be processed. # Determine what columns are ready to be processed.
array_columns_to_process,\ array_columns_to_process, unprocessed_array_columns = cls._split_array_columns_by_process_state(
unprocessed_array_columns = cls._split_array_columns_by_process_state( array_columns, array_column_hierarchy, data[0]
array_columns, array_column_hierarchy, data[0]) )
all_array_data = cls._process_array_data(data, all_array_data = cls._process_array_data(
all_columns, data, all_columns, array_column_hierarchy
array_column_hierarchy) )
# Consolidate the original data set and the expanded array data # Consolidate the original data set and the expanded array data
cls._consolidate_array_data_into_data(data, all_array_data) cls._consolidate_array_data_into_data(data, all_array_data)
# Remove processed array columns from the graph # Remove processed array columns from the graph
cls._remove_processed_array_columns(unprocessed_array_columns, cls._remove_processed_array_columns(
array_column_hierarchy) unprocessed_array_columns, array_column_hierarchy
)
return all_columns, data, expanded_columns return all_columns, data, expanded_columns
@ -748,25 +789,26 @@ class PrestoEngineSpec(BaseEngineSpec):
indexes = database.get_indexes(table_name, schema_name) indexes = database.get_indexes(table_name, schema_name)
if not indexes: if not indexes:
return {} return {}
cols = indexes[0].get('column_names', []) cols = indexes[0].get("column_names", [])
full_table_name = table_name full_table_name = table_name
if schema_name and '.' not in table_name: if schema_name and "." not in table_name:
full_table_name = '{}.{}'.format(schema_name, table_name) full_table_name = "{}.{}".format(schema_name, table_name)
pql = cls._partition_query(full_table_name) pql = cls._partition_query(full_table_name)
col_name, latest_part = cls.latest_partition( col_name, latest_part = cls.latest_partition(
table_name, schema_name, database, show_first=True) table_name, schema_name, database, show_first=True
)
return { return {
'partitions': { "partitions": {
'cols': cols, "cols": cols,
'latest': {col_name: latest_part}, "latest": {col_name: latest_part},
'partitionQuery': pql, "partitionQuery": pql,
}, }
} }
@classmethod @classmethod
def handle_cursor(cls, cursor, query, session): def handle_cursor(cls, cursor, query, session):
"""Updates progress information""" """Updates progress information"""
logging.info('Polling the cursor for progress') logging.info("Polling the cursor for progress")
polled = cursor.poll() polled = cursor.poll()
# poll returns dict -- JSON status information or ``None`` # poll returns dict -- JSON status information or ``None``
# if the query is done # if the query is done
@ -774,7 +816,7 @@ class PrestoEngineSpec(BaseEngineSpec):
# b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178 # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
while polled: while polled:
# Update the object and wait for the kill signal. # Update the object and wait for the kill signal.
stats = polled.get('stats', {}) stats = polled.get("stats", {})
query = session.query(type(query)).filter_by(id=query.id).one() query = session.query(type(query)).filter_by(id=query.id).one()
if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]: if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
@ -782,50 +824,51 @@ class PrestoEngineSpec(BaseEngineSpec):
break break
if stats: if stats:
state = stats.get('state') state = stats.get("state")
# if already finished, then stop polling # if already finished, then stop polling
if state == 'FINISHED': if state == "FINISHED":
break break
completed_splits = float(stats.get('completedSplits')) completed_splits = float(stats.get("completedSplits"))
total_splits = float(stats.get('totalSplits')) total_splits = float(stats.get("totalSplits"))
if total_splits and completed_splits: if total_splits and completed_splits:
progress = 100 * (completed_splits / total_splits) progress = 100 * (completed_splits / total_splits)
logging.info( logging.info(
'Query progress: {} / {} ' "Query progress: {} / {} "
'splits'.format(completed_splits, total_splits)) "splits".format(completed_splits, total_splits)
)
if progress > query.progress: if progress > query.progress:
query.progress = progress query.progress = progress
session.commit() session.commit()
time.sleep(1) time.sleep(1)
logging.info('Polling the cursor for progress') logging.info("Polling the cursor for progress")
polled = cursor.poll() polled = cursor.poll()
@classmethod @classmethod
def extract_error_message(cls, e): def extract_error_message(cls, e):
if ( if (
hasattr(e, 'orig') and hasattr(e, "orig")
type(e.orig).__name__ == 'DatabaseError' and and type(e.orig).__name__ == "DatabaseError"
isinstance(e.orig[0], dict)): and isinstance(e.orig[0], dict)
):
error_dict = e.orig[0] error_dict = e.orig[0]
return '{} at {}: {}'.format( return "{} at {}: {}".format(
error_dict.get('errorName'), error_dict.get("errorName"),
error_dict.get('errorLocation'), error_dict.get("errorLocation"),
error_dict.get('message'), error_dict.get("message"),
) )
if ( if (
type(e).__name__ == 'DatabaseError' and type(e).__name__ == "DatabaseError"
hasattr(e, 'args') and and hasattr(e, "args")
len(e.args) > 0 and len(e.args) > 0
): ):
error_dict = e.args[0] error_dict = e.args[0]
return error_dict.get('message') return error_dict.get("message")
return utils.error_msg_from_exception(e) return utils.error_msg_from_exception(e)
@classmethod @classmethod
def _partition_query( def _partition_query(cls, table_name, limit=0, order_by=None, filters=None):
cls, table_name, limit=0, order_by=None, filters=None):
"""Returns a partition query """Returns a partition query
:param table_name: the name of the table to get partitions from :param table_name: the name of the table to get partitions from
@ -838,42 +881,44 @@ class PrestoEngineSpec(BaseEngineSpec):
:type order_by: list of (str, bool) tuples :type order_by: list of (str, bool) tuples
:param filters: dict of field name and filter value combinations :param filters: dict of field name and filter value combinations
""" """
limit_clause = 'LIMIT {}'.format(limit) if limit else '' limit_clause = "LIMIT {}".format(limit) if limit else ""
order_by_clause = '' order_by_clause = ""
if order_by: if order_by:
l = [] # noqa: E741 l = [] # noqa: E741
for field, desc in order_by: for field, desc in order_by:
l.append(field + ' DESC' if desc else '') l.append(field + " DESC" if desc else "")
order_by_clause = 'ORDER BY ' + ', '.join(l) order_by_clause = "ORDER BY " + ", ".join(l)
where_clause = '' where_clause = ""
if filters: if filters:
l = [] # noqa: E741 l = [] # noqa: E741
for field, value in filters.items(): for field, value in filters.items():
l.append(f"{field} = '{value}'") l.append(f"{field} = '{value}'")
where_clause = 'WHERE ' + ' AND '.join(l) where_clause = "WHERE " + " AND ".join(l)
sql = textwrap.dedent(f"""\ sql = textwrap.dedent(
f"""\
SELECT * FROM "{table_name}$partitions" SELECT * FROM "{table_name}$partitions"
{where_clause} {where_clause}
{order_by_clause} {order_by_clause}
{limit_clause} {limit_clause}
""") """
)
return sql return sql
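Putting the clause building above together on hypothetical inputs (a "logs" table partitioned by ds and hour) yields a query like the following sketch:

table_name, limit = "logs", 1
order_by, filters = [("ds", True)], {"hour": "12"}

limit_clause = "LIMIT {}".format(limit) if limit else ""
order_by_clause = "ORDER BY " + ", ".join(
    field + " DESC" if desc else "" for field, desc in order_by
)
where_clause = "WHERE " + " AND ".join(
    f"{field} = '{value}'" for field, value in filters.items()
)
print(f'SELECT * FROM "{table_name}$partitions" {where_clause} {order_by_clause} {limit_clause}')
# SELECT * FROM "logs$partitions" WHERE hour = '12' ORDER BY ds DESC LIMIT 1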
@classmethod @classmethod
def where_latest_partition( def where_latest_partition(cls, table_name, schema, database, qry, columns=None):
cls, table_name, schema, database, qry, columns=None):
try: try:
col_name, value = cls.latest_partition( col_name, value = cls.latest_partition(
table_name, schema, database, show_first=True) table_name, schema, database, show_first=True
)
except Exception: except Exception:
# table is not partitioned # table is not partitioned
return False return False
if value is not None: if value is not None:
for c in columns: for c in columns:
if c.get('name') == col_name: if c.get("name") == col_name:
return qry.where(Column(col_name) == value) return qry.where(Column(col_name) == value)
return False return False
@ -900,15 +945,17 @@ class PrestoEngineSpec(BaseEngineSpec):
('ds', '2018-01-01') ('ds', '2018-01-01')
""" """
indexes = database.get_indexes(table_name, schema) indexes = database.get_indexes(table_name, schema)
if len(indexes[0]['column_names']) < 1: if len(indexes[0]["column_names"]) < 1:
raise SupersetTemplateException( raise SupersetTemplateException(
'The table should have one partitioned field') "The table should have one partitioned field"
elif not show_first and len(indexes[0]['column_names']) > 1: )
elif not show_first and len(indexes[0]["column_names"]) > 1:
raise SupersetTemplateException( raise SupersetTemplateException(
'The table should have a single partitioned field ' "The table should have a single partitioned field "
'to use this function. You may want to use ' "to use this function. You may want to use "
'`presto.latest_sub_partition`') "`presto.latest_sub_partition`"
part_field = indexes[0]['column_names'][0] )
part_field = indexes[0]["column_names"][0]
sql = cls._partition_query(table_name, 1, [(part_field, True)]) sql = cls._partition_query(table_name, 1, [(part_field, True)])
df = database.get_df(sql, schema) df = database.get_df(sql, schema)
return part_field, cls._latest_partition_from_df(df) return part_field, cls._latest_partition_from_df(df)
@ -941,15 +988,14 @@ class PrestoEngineSpec(BaseEngineSpec):
'2018-01-01' '2018-01-01'
""" """
indexes = database.get_indexes(table_name, schema) indexes = database.get_indexes(table_name, schema)
part_fields = indexes[0]['column_names'] part_fields = indexes[0]["column_names"]
for k in kwargs.keys(): for k in kwargs.keys():
if k not in part_fields: if k not in part_fields:
msg = f'Field [{k}] is not part of the partitioning key' msg = f"Field [{k}] is not part of the partitioning key"
raise SupersetTemplateException(msg) raise SupersetTemplateException(msg)
if len(kwargs.keys()) != len(part_fields) - 1: if len(kwargs.keys()) != len(part_fields) - 1:
msg = ( msg = (
'A filter needs to be specified for {} out of the ' "A filter needs to be specified for {} out of the " "{} fields."
'{} fields.'
).format(len(part_fields) - 1, len(part_fields)) ).format(len(part_fields) - 1, len(part_fields))
raise SupersetTemplateException(msg) raise SupersetTemplateException(msg)
@ -957,9 +1003,8 @@ class PrestoEngineSpec(BaseEngineSpec):
if field not in kwargs.keys(): if field not in kwargs.keys():
field_to_return = field field_to_return = field
sql = cls._partition_query( sql = cls._partition_query(table_name, 1, [(field_to_return, True)], kwargs)
table_name, 1, [(field_to_return, True)], kwargs)
df = database.get_df(sql, schema) df = database.get_df(sql, schema)
if df.empty: if df.empty:
return '' return ""
return df.to_dict()[field_to_return][0] return df.to_dict()[field_to_return][0]
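As a usage sketch of the _partition_query helper shown above, this is roughly the SQL it assembles; the module path, table name, filter, and ordering below are illustrative assumptions, not values taken from this commit.

# Assumed module path; "logs" and the ds filter are made-up examples.
from superset.db_engine_specs.presto import PrestoEngineSpec

sql = PrestoEngineSpec._partition_query(
    "logs", limit=1, order_by=[("ds", True)], filters={"ds": "2019-06-25"}
)
# Roughly yields:
#   SELECT * FROM "logs$partitions"
#   WHERE ds = '2019-06-25'
#   ORDER BY ds DESC
#   LIMIT 1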


@ -19,7 +19,7 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class RedshiftEngineSpec(PostgresBaseEngineSpec): class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift' engine = "redshift"
max_column_name_length = 127 max_column_name_length = 127
@staticmethod @staticmethod


@ -21,38 +21,38 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class SnowflakeEngineSpec(PostgresBaseEngineSpec): class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake' engine = "snowflake"
force_column_alias_quotes = True force_column_alias_quotes = True
max_column_name_length = 256 max_column_name_length = 256
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1S': "DATE_TRUNC('SECOND', {col})", "PT1S": "DATE_TRUNC('SECOND', {col})",
'PT1M': "DATE_TRUNC('MINUTE', {col})", "PT1M": "DATE_TRUNC('MINUTE', {col})",
'PT5M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \ "PT5M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \
DATE_TRUNC('HOUR', {col}))", DATE_TRUNC('HOUR', {col}))",
'PT10M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \ "PT10M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \
DATE_TRUNC('HOUR', {col}))", DATE_TRUNC('HOUR', {col}))",
'PT15M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \ "PT15M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \
DATE_TRUNC('HOUR', {col}))", DATE_TRUNC('HOUR', {col}))",
'PT0.5H': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \ "PT0.5H": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \
DATE_TRUNC('HOUR', {col}))", DATE_TRUNC('HOUR', {col}))",
'PT1H': "DATE_TRUNC('HOUR', {col})", "PT1H": "DATE_TRUNC('HOUR', {col})",
'P1D': "DATE_TRUNC('DAY', {col})", "P1D": "DATE_TRUNC('DAY', {col})",
'P1W': "DATE_TRUNC('WEEK', {col})", "P1W": "DATE_TRUNC('WEEK', {col})",
'P1M': "DATE_TRUNC('MONTH', {col})", "P1M": "DATE_TRUNC('MONTH', {col})",
'P0.25Y': "DATE_TRUNC('QUARTER', {col})", "P0.25Y": "DATE_TRUNC('QUARTER', {col})",
'P1Y': "DATE_TRUNC('YEAR', {col})", "P1Y": "DATE_TRUNC('YEAR', {col})",
} }
@classmethod @classmethod
def adjust_database_uri(cls, uri, selected_schema=None): def adjust_database_uri(cls, uri, selected_schema=None):
database = uri.database database = uri.database
if '/' in uri.database: if "/" in uri.database:
database = uri.database.split('/')[0] database = uri.database.split("/")[0]
if selected_schema: if selected_schema:
selected_schema = parse.quote(selected_schema, safe='') selected_schema = parse.quote(selected_schema, safe="")
uri.database = database + '/' + selected_schema uri.database = database + "/" + selected_schema
return uri return uri
@classmethod @classmethod
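Each entry in time_grain_functions above is a template with a {col} placeholder; a small sketch of how one expands (module path and column name are assumptions for illustration):

# Assumed module path; created_on is a made-up column name.
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec

daily = SnowflakeEngineSpec.time_grain_functions["P1D"]
print(daily.format(col="created_on"))  # DATE_TRUNC('DAY', created_on)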


@ -22,17 +22,17 @@ from superset.utils import core as utils
class SqliteEngineSpec(BaseEngineSpec): class SqliteEngineSpec(BaseEngineSpec):
engine = 'sqlite' engine = "sqlite"
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1H': "DATETIME(STRFTIME('%Y-%m-%dT%H:00:00', {col}))", "PT1H": "DATETIME(STRFTIME('%Y-%m-%dT%H:00:00', {col}))",
'P1D': 'DATE({col})', "P1D": "DATE({col})",
'P1W': "DATE({col}, -strftime('%W', {col}) || ' days')", "P1W": "DATE({col}, -strftime('%W', {col}) || ' days')",
'P1M': "DATE({col}, -strftime('%d', {col}) || ' days', '+1 day')", "P1M": "DATE({col}, -strftime('%d', {col}) || ' days', '+1 day')",
'P1Y': "DATETIME(STRFTIME('%Y-01-01T00:00:00', {col}))", "P1Y": "DATETIME(STRFTIME('%Y-01-01T00:00:00', {col}))",
'P1W/1970-01-03T00:00:00Z': "DATE({col}, 'weekday 6')", "P1W/1970-01-03T00:00:00Z": "DATE({col}, 'weekday 6')",
'1969-12-28T00:00:00Z/P1W': "DATE({col}, 'weekday 0', '-7 days')", "1969-12-28T00:00:00Z/P1W": "DATE({col}, 'weekday 0', '-7 days')",
} }
@classmethod @classmethod
@ -40,30 +40,37 @@ class SqliteEngineSpec(BaseEngineSpec):
return "datetime({col}, 'unixepoch')" return "datetime({col}, 'unixepoch')"
@classmethod @classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \ def get_all_datasource_names(
-> List[utils.DatasourceName]: cls, db, datasource_type: str
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled, ) -> List[utils.DatasourceName]:
cache_timeout=db.schema_cache_timeout, schemas = db.get_all_schema_names(
force=True) cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True,
)
schema = schemas[0] schema = schemas[0]
if datasource_type == 'table': if datasource_type == "table":
return db.get_all_table_names_in_schema( return db.get_all_table_names_in_schema(
schema=schema, force=True, schema=schema,
force=True,
cache=db.table_cache_enabled, cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout) cache_timeout=db.table_cache_timeout,
elif datasource_type == 'view': )
elif datasource_type == "view":
return db.get_all_view_names_in_schema( return db.get_all_view_names_in_schema(
schema=schema, force=True, schema=schema,
force=True,
cache=db.table_cache_enabled, cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout) cache_timeout=db.table_cache_timeout,
)
else: else:
raise Exception(f'Unsupported datasource_type: {datasource_type}') raise Exception(f"Unsupported datasource_type: {datasource_type}")
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
iso = dttm.isoformat().replace('T', ' ') iso = dttm.isoformat().replace("T", " ")
if '.' not in iso: if "." not in iso:
iso += '.000000' iso += ".000000"
return "'{}'".format(iso) return "'{}'".format(iso)
@classmethod @classmethod
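A quick sketch of what the convert_dttm shown above returns; the module path is assumed, and the "TEXT" target type is an arbitrary placeholder that this implementation ignores.

from datetime import datetime

from superset.db_engine_specs.sqlite import SqliteEngineSpec  # assumed module path

print(SqliteEngineSpec.convert_dttm("TEXT", datetime(2019, 6, 25, 13, 34, 48)))
# '2019-06-25 13:34:48.000000'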


@ -20,23 +20,26 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class TeradataEngineSpec(BaseEngineSpec): class TeradataEngineSpec(BaseEngineSpec):
"""Dialect for Teradata DB.""" """Dialect for Teradata DB."""
engine = 'teradata'
engine = "teradata"
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 30 # since 14.10 this is 128 max_column_name_length = 30 # since 14.10 this is 128
time_grain_functions = { time_grain_functions = {
None: '{col}', None: "{col}",
'PT1M': "TRUNC(CAST({col} as DATE), 'MI')", "PT1M": "TRUNC(CAST({col} as DATE), 'MI')",
'PT1H': "TRUNC(CAST({col} as DATE), 'HH')", "PT1H": "TRUNC(CAST({col} as DATE), 'HH')",
'P1D': "TRUNC(CAST({col} as DATE), 'DDD')", "P1D": "TRUNC(CAST({col} as DATE), 'DDD')",
'P1W': "TRUNC(CAST({col} as DATE), 'WW')", "P1W": "TRUNC(CAST({col} as DATE), 'WW')",
'P1M': "TRUNC(CAST({col} as DATE), 'MONTH')", "P1M": "TRUNC(CAST({col} as DATE), 'MONTH')",
'P0.25Y': "TRUNC(CAST({col} as DATE), 'Q')", "P0.25Y": "TRUNC(CAST({col} as DATE), 'Q')",
'P1Y': "TRUNC(CAST({col} as DATE), 'YEAR')", "P1Y": "TRUNC(CAST({col} as DATE), 'YEAR')",
} }
@classmethod @classmethod
def epoch_to_dttm(cls): def epoch_to_dttm(cls):
return "CAST(((CAST(DATE '1970-01-01' + ({col} / 86400) AS TIMESTAMP(0) " \ return (
"AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' " \ "CAST(((CAST(DATE '1970-01-01' + ({col} / 86400) AS TIMESTAMP(0) "
'HOUR TO SECOND) AS TIMESTAMP(0))' "AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' "
"HOUR TO SECOND) AS TIMESTAMP(0))"
)
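The string returned by epoch_to_dttm above carries a {col} placeholder that callers are expected to fill with the epoch column expression; a sketch with a hypothetical column epoch_s (module path assumed):

from superset.db_engine_specs.teradata import TeradataEngineSpec  # assumed module path

print(TeradataEngineSpec.epoch_to_dttm().format(col="epoch_s"))
# CAST(((CAST(DATE '1970-01-01' + (epoch_s / 86400) AS TIMESTAMP(0) AT 0)) AT 0)
# + ((epoch_s MOD 86400) * INTERVAL '00:00:01' HOUR TO SECOND) AS TIMESTAMP(0))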


@ -19,4 +19,4 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class VerticaEngineSpec(PostgresBaseEngineSpec): class VerticaEngineSpec(PostgresBaseEngineSpec):
engine = 'vertica' engine = "vertica"


@ -18,8 +18,7 @@
# TODO: contribute back to pyhive. # TODO: contribute back to pyhive.
def fetch_logs(self, max_rows=1024, def fetch_logs(self, max_rows=1024, orientation=None):
orientation=None):
"""Mocked. Retrieve the logs produced by the execution of the query. """Mocked. Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after Can be called multiple times to fetch the logs produced after
the previous call. the previous call.
@ -29,18 +28,18 @@ def fetch_logs(self, max_rows=1024,
This is not a part of DB-API. This is not a part of DB-API.
""" """
from pyhive import hive # noqa from pyhive import hive # noqa
from TCLIService import ttypes # noqa from TCLIService import ttypes # noqa
from thrift import Thrift # pylint: disable=import-error from thrift import Thrift # pylint: disable=import-error
orientation = orientation or ttypes.TFetchOrientation.FETCH_NEXT orientation = orientation or ttypes.TFetchOrientation.FETCH_NEXT
try: try:
req = ttypes.TGetLogReq(operationHandle=self._operationHandle) req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
logs = self._connection.client.GetLog(req).log logs = self._connection.client.GetLog(req).log
return logs return logs
# raised if Hive is used # raised if Hive is used
except (ttypes.TApplicationException, except (ttypes.TApplicationException, Thrift.TApplicationException):
Thrift.TApplicationException):
if self._state == self._STATE_NONE: if self._state == self._STATE_NONE:
raise hive.ProgrammingError('No query yet') raise hive.ProgrammingError("No query yet")
logs = [] logs = []
while True: while True:
req = ttypes.TFetchResultsReq( req = ttypes.TFetchResultsReq(
@ -51,11 +50,10 @@ def fetch_logs(self, max_rows=1024,
) )
response = self._connection.client.FetchResults(req) response = self._connection.client.FetchResults(req)
hive._check_status(response) hive._check_status(response)
assert not response.results.rows, \ assert not response.results.rows, "expected data in columnar format"
'expected data in columnar format'
assert len(response.results.columns) == 1, response.results.columns assert len(response.results.columns) == 1, response.results.columns
new_logs = hive._unwrap_column(response.results.columns[0]) new_logs = hive._unwrap_column(response.results.columns[0])
logs += new_logs logs += new_logs
if not new_logs: if not new_logs:
break break
return '\n'.join(logs) return "\n".join(logs)


@ -36,7 +36,7 @@ def is_subselect(parsed):
if not parsed.is_group(): if not parsed.is_group():
return False return False
for item in parsed.tokens: for item in parsed.tokens:
if item.ttype is DML and item.value.upper() == 'SELECT': if item.ttype is DML and item.value.upper() == "SELECT":
return True return True
return False return False
@ -52,7 +52,7 @@ def extract_from_part(parsed):
raise StopIteration raise StopIteration
else: else:
yield item yield item
elif item.ttype is Keyword and item.value.upper() == 'FROM': elif item.ttype is Keyword and item.value.upper() == "FROM":
from_seen = True from_seen = True


@ -20,8 +20,7 @@ from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_appbuilder.forms import DynamicForm from flask_appbuilder.forms import DynamicForm
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from flask_wtf.file import FileAllowed, FileField, FileRequired from flask_wtf.file import FileAllowed, FileField, FileRequired
from wtforms import ( from wtforms import BooleanField, Field, IntegerField, SelectField, StringField
BooleanField, Field, IntegerField, SelectField, StringField)
from wtforms.ext.sqlalchemy.fields import QuerySelectField from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import DataRequired, Length, NumberRange, Optional from wtforms.validators import DataRequired, Length, NumberRange, Optional
@ -36,13 +35,13 @@ class CommaSeparatedListField(Field):
def _value(self): def _value(self):
if self.data: if self.data:
return u', '.join(self.data) return u", ".join(self.data)
else: else:
return u'' return u""
def process_formdata(self, valuelist): def process_formdata(self, valuelist):
if valuelist: if valuelist:
self.data = [x.strip() for x in valuelist[0].split(',')] self.data = [x.strip() for x in valuelist[0].split(",")]
else: else:
self.data = [] self.data = []
@ -61,9 +60,9 @@ class CsvToDatabaseForm(DynamicForm):
# pylint: disable=E0211 # pylint: disable=E0211
def csv_allowed_dbs(): def csv_allowed_dbs():
csv_allowed_dbs = [] csv_allowed_dbs = []
csv_enabled_dbs = db.session.query( csv_enabled_dbs = (
models.Database).filter_by( db.session.query(models.Database).filter_by(allow_csv_upload=True).all()
allow_csv_upload=True).all() )
for csv_enabled_db in csv_enabled_dbs: for csv_enabled_db in csv_enabled_dbs:
if CsvToDatabaseForm.at_least_one_schema_is_allowed(csv_enabled_db): if CsvToDatabaseForm.at_least_one_schema_is_allowed(csv_enabled_db):
csv_allowed_dbs.append(csv_enabled_db) csv_allowed_dbs.append(csv_enabled_db)
@ -95,110 +94,132 @@ class CsvToDatabaseForm(DynamicForm):
b) if the database supports schemas b) if the database supports schemas
the user is able to upload to a schema in schemas_allowed_for_csv_upload the user is able to upload to a schema in schemas_allowed_for_csv_upload
""" """
if (security_manager.database_access(database) or if (
security_manager.all_datasource_access()): security_manager.database_access(database)
or security_manager.all_datasource_access()
):
return True return True
schemas = database.get_schema_access_for_csv_upload() schemas = database.get_schema_access_for_csv_upload()
if (schemas and if schemas and security_manager.schemas_accessible_by_user(
security_manager.schemas_accessible_by_user( database, schemas, False
database, schemas, False)): ):
return True return True
return False return False
name = StringField( name = StringField(
_('Table Name'), _("Table Name"),
description=_('Name of table to be created from csv data.'), description=_("Name of table to be created from csv data."),
validators=[DataRequired()], validators=[DataRequired()],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
csv_file = FileField( csv_file = FileField(
_('CSV File'), _("CSV File"),
description=_('Select a CSV file to be uploaded to a database.'), description=_("Select a CSV file to be uploaded to a database."),
validators=[ validators=[FileRequired(), FileAllowed(["csv"], _("CSV Files Only!"))],
FileRequired(), FileAllowed(['csv'], _('CSV Files Only!'))]) )
con = QuerySelectField( con = QuerySelectField(
_('Database'), _("Database"),
query_factory=csv_allowed_dbs, query_factory=csv_allowed_dbs,
get_pk=lambda a: a.id, get_label=lambda a: a.database_name) get_pk=lambda a: a.id,
get_label=lambda a: a.database_name,
)
schema = StringField( schema = StringField(
_('Schema'), _("Schema"),
description=_('Specify a schema (if database flavor supports this).'), description=_("Specify a schema (if database flavor supports this)."),
validators=[Optional()], validators=[Optional()],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
sep = StringField( sep = StringField(
_('Delimiter'), _("Delimiter"),
description=_('Delimiter used by CSV file (for whitespace use \s+).'), description=_("Delimiter used by CSV file (for whitespace use \\s+)."),
validators=[DataRequired()], validators=[DataRequired()],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
if_exists = SelectField( if_exists = SelectField(
_('Table Exists'), _("Table Exists"),
description=_( description=_(
'If table exists do one of the following: ' "If table exists do one of the following: "
'Fail (do nothing), Replace (drop and recreate table) ' "Fail (do nothing), Replace (drop and recreate table) "
'or Append (insert data).'), "or Append (insert data)."
),
choices=[ choices=[
('fail', _('Fail')), ('replace', _('Replace')), ("fail", _("Fail")),
('append', _('Append'))], ("replace", _("Replace")),
validators=[DataRequired()]) ("append", _("Append")),
],
validators=[DataRequired()],
)
header = IntegerField( header = IntegerField(
_('Header Row'), _("Header Row"),
description=_( description=_(
'Row containing the headers to use as ' "Row containing the headers to use as "
'column names (0 is first line of data). ' "column names (0 is first line of data). "
'Leave empty if there is no header row.'), "Leave empty if there is no header row."
),
validators=[Optional(), NumberRange(min=0)], validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
index_col = IntegerField( index_col = IntegerField(
_('Index Column'), _("Index Column"),
description=_( description=_(
'Column to use as the row labels of the ' "Column to use as the row labels of the "
'dataframe. Leave empty if no index column.'), "dataframe. Leave empty if no index column."
),
validators=[Optional(), NumberRange(min=0)], validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
mangle_dupe_cols = BooleanField( mangle_dupe_cols = BooleanField(
_('Mangle Duplicate Columns'), _("Mangle Duplicate Columns"),
description=_('Specify duplicate columns as "X.0, X.1".')) description=_('Specify duplicate columns as "X.0, X.1".'),
)
skipinitialspace = BooleanField( skipinitialspace = BooleanField(
_('Skip Initial Space'), _("Skip Initial Space"), description=_("Skip spaces after delimiter.")
description=_('Skip spaces after delimiter.')) )
skiprows = IntegerField( skiprows = IntegerField(
_('Skip Rows'), _("Skip Rows"),
description=_('Number of rows to skip at start of file.'), description=_("Number of rows to skip at start of file."),
validators=[Optional(), NumberRange(min=0)], validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
nrows = IntegerField( nrows = IntegerField(
_('Rows to Read'), _("Rows to Read"),
description=_('Number of rows of file to read.'), description=_("Number of rows of file to read."),
validators=[Optional(), NumberRange(min=0)], validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
skip_blank_lines = BooleanField( skip_blank_lines = BooleanField(
_('Skip Blank Lines'), _("Skip Blank Lines"),
description=_( description=_(
'Skip blank lines rather than interpreting them ' "Skip blank lines rather than interpreting them " "as NaN values."
'as NaN values.')) ),
)
parse_dates = CommaSeparatedListField( parse_dates = CommaSeparatedListField(
_('Parse Dates'), _("Parse Dates"),
description=_( description=_(
'A comma separated list of columns that should be ' "A comma separated list of columns that should be " "parsed as dates."
'parsed as dates.'), ),
filters=[filter_not_empty_values]) filters=[filter_not_empty_values],
)
infer_datetime_format = BooleanField( infer_datetime_format = BooleanField(
_('Infer Datetime Format'), _("Infer Datetime Format"),
description=_( description=_("Use Pandas to interpret the datetime format " "automatically."),
'Use Pandas to interpret the datetime format ' )
'automatically.'))
decimal = StringField( decimal = StringField(
_('Decimal Character'), _("Decimal Character"),
default='.', default=".",
description=_('Character to interpret as decimal point.'), description=_("Character to interpret as decimal point."),
validators=[Optional(), Length(min=1, max=1)], validators=[Optional(), Length(min=1, max=1)],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)
index = BooleanField( index = BooleanField(
_('Dataframe Index'), _("Dataframe Index"), description=_("Write dataframe index as a column.")
description=_('Write dataframe index as a column.')) )
index_label = StringField( index_label = StringField(
_('Column Label(s)'), _("Column Label(s)"),
description=_( description=_(
'Column label for index column(s). If None is given ' "Column label for index column(s). If None is given "
'and Dataframe Index is True, Index Names are used.'), "and Dataframe Index is True, Index Names are used."
),
validators=[Optional()], validators=[Optional()],
widget=BS3TextFieldWidget()) widget=BS3TextFieldWidget(),
)


@ -31,14 +31,14 @@ from superset import app
config = app.config config = app.config
BASE_CONTEXT = { BASE_CONTEXT = {
'datetime': datetime, "datetime": datetime,
'random': random, "random": random,
'relativedelta': relativedelta, "relativedelta": relativedelta,
'time': time, "time": time,
'timedelta': timedelta, "timedelta": timedelta,
'uuid': uuid, "uuid": uuid,
} }
BASE_CONTEXT.update(config.get('JINJA_CONTEXT_ADDONS', {})) BASE_CONTEXT.update(config.get("JINJA_CONTEXT_ADDONS", {}))
def url_param(param, default=None): def url_param(param, default=None):
@ -63,16 +63,16 @@ def url_param(param, default=None):
if request.args.get(param): if request.args.get(param):
return request.args.get(param, default) return request.args.get(param, default)
# Supporting POST as well as get # Supporting POST as well as get
if request.form.get('form_data'): if request.form.get("form_data"):
form_data = json.loads(request.form.get('form_data')) form_data = json.loads(request.form.get("form_data"))
url_params = form_data.get('url_params') or {} url_params = form_data.get("url_params") or {}
return url_params.get(param, default) return url_params.get(param, default)
return default return default
def current_user_id(): def current_user_id():
"""The id of the user who is currently logged in""" """The id of the user who is currently logged in"""
if hasattr(g, 'user') and g.user: if hasattr(g, "user") and g.user:
return g.user.id return g.user.id
@ -88,7 +88,8 @@ def filter_values(column, default=None):
This is useful if: This is useful if:
- you want to use a filter box to filter a query where the name of filter box - you want to use a filter box to filter a query where the name of filter box
column doesn't match the one in the select statement column doesn't match the one in the select statement
- you want to have the ability to filter inside the main query for speed purposes - you want to have the ability to filter inside the main query for speed
purposes
This searches for "filters" and "extra_filters" in form_data for a match This searches for "filters" and "extra_filters" in form_data for a match
@ -105,19 +106,19 @@ def filter_values(column, default=None):
:return: returns a list of filter values :return: returns a list of filter values
:type: list :type: list
""" """
form_data = json.loads(request.form.get('form_data', '{}')) form_data = json.loads(request.form.get("form_data", "{}"))
return_val = [] return_val = []
for filter_type in ['filters', 'extra_filters']: for filter_type in ["filters", "extra_filters"]:
if filter_type not in form_data: if filter_type not in form_data:
continue continue
for f in form_data[filter_type]: for f in form_data[filter_type]:
if f['col'] == column: if f["col"] == column:
if isinstance(f['val'], list): if isinstance(f["val"], list):
for v in f['val']: for v in f["val"]:
return_val.append(v) return_val.append(v)
else: else:
return_val.append(f['val']) return_val.append(f["val"])
if return_val: if return_val:
return return_val return return_val
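A hedged usage sketch for filter_values above; the module path, the shape of the POSTed form_data, and the column and values are assumptions made for illustration only.

import json

from superset import app  # assumes the Flask app is importable as in this release
from superset.jinja_context import filter_values  # assumed module path

form_data = {"filters": [{"col": "action", "op": "in", "val": ["login", "logout"]}]}
with app.test_request_context(method="POST", data={"form_data": json.dumps(form_data)}):
    print(filter_values("action"))  # ['login', 'logout']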
@ -142,6 +143,7 @@ class BaseTemplateProcessor(object):
and are given access to the ``models.Database`` object and schema and are given access to the ``models.Database`` object and schema
name. For globally available methods use ``@classmethod``. name. For globally available methods use ``@classmethod``.
""" """
engine = None engine = None
def __init__(self, database=None, query=None, table=None, **kwargs): def __init__(self, database=None, query=None, table=None, **kwargs):
@ -153,11 +155,11 @@ class BaseTemplateProcessor(object):
elif table: elif table:
self.schema = table.schema self.schema = table.schema
self.context = { self.context = {
'url_param': url_param, "url_param": url_param,
'current_user_id': current_user_id, "current_user_id": current_user_id,
'current_username': current_username, "current_username": current_username,
'filter_values': filter_values, "filter_values": filter_values,
'form_data': {}, "form_data": {},
} }
self.context.update(kwargs) self.context.update(kwargs)
self.context.update(BASE_CONTEXT) self.context.update(BASE_CONTEXT)
@ -183,30 +185,30 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
The methods described here are namespaced under ``presto`` in the The methods described here are namespaced under ``presto`` in the
jinja context as in ``SELECT '{{ presto.some_macro_call() }}'`` jinja context as in ``SELECT '{{ presto.some_macro_call() }}'``
""" """
engine = 'presto'
engine = "presto"
@staticmethod @staticmethod
def _schema_table(table_name, schema): def _schema_table(table_name, schema):
if '.' in table_name: if "." in table_name:
schema, table_name = table_name.split('.') schema, table_name = table_name.split(".")
return table_name, schema return table_name, schema
def latest_partition(self, table_name): def latest_partition(self, table_name):
table_name, schema = self._schema_table(table_name, self.schema) table_name, schema = self._schema_table(table_name, self.schema)
return self.database.db_engine_spec.latest_partition( return self.database.db_engine_spec.latest_partition(
table_name, schema, self.database)[1] table_name, schema, self.database
)[1]
def latest_sub_partition(self, table_name, **kwargs): def latest_sub_partition(self, table_name, **kwargs):
table_name, schema = self._schema_table(table_name, self.schema) table_name, schema = self._schema_table(table_name, self.schema)
return self.database.db_engine_spec.latest_sub_partition( return self.database.db_engine_spec.latest_sub_partition(
table_name=table_name, table_name=table_name, schema=schema, database=self.database, **kwargs
schema=schema, )
database=self.database,
**kwargs)
class HiveTemplateProcessor(PrestoTemplateProcessor): class HiveTemplateProcessor(PrestoTemplateProcessor):
engine = 'hive' engine = "hive"
template_processors = {} template_processors = {}

View File

@ -20,8 +20,7 @@
def update_time_range(form_data): def update_time_range(form_data):
"""Move since and until to time_range.""" """Move since and until to time_range."""
if 'since' in form_data or 'until' in form_data: if "since" in form_data or "until" in form_data:
form_data['time_range'] = '{} : {}'.format( form_data["time_range"] = "{} : {}".format(
form_data.pop('since', '') or '', form_data.pop("since", "") or "", form_data.pop("until", "") or ""
form_data.pop('until', '') or '',
) )
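A minimal sketch of update_time_range above; the import path and the form_data values are assumptions for illustration.

from superset.legacy import update_time_range  # assumed module path

form_data = {"since": "2019-01-01", "until": "2019-06-25", "viz_type": "table"}
update_time_range(form_data)
print(form_data)
# {'viz_type': 'table', 'time_range': '2019-01-01 : 2019-06-25'}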


@ -29,16 +29,17 @@ config = context.config
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
fileConfig(config.config_file_name) fileConfig(config.config_file_name)
logger = logging.getLogger('alembic.env') logger = logging.getLogger("alembic.env")
# add your model's MetaData object here # add your model's MetaData object here
# for 'autogenerate' support # for 'autogenerate' support
# from myapp import mymodel # from myapp import mymodel
from flask import current_app from flask import current_app
config.set_main_option('sqlalchemy.url', config.set_main_option(
current_app.config.get('SQLALCHEMY_DATABASE_URI')) "sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI")
target_metadata = Base.metadata # pylint: disable=no-member )
target_metadata = Base.metadata # pylint: disable=no-member
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
@ -77,32 +78,33 @@ def run_migrations_online():
# when there are no changes to the schema # when there are no changes to the schema
# reference: https://alembic.sqlalchemy.org/en/latest/cookbook.html # reference: https://alembic.sqlalchemy.org/en/latest/cookbook.html
def process_revision_directives(context, revision, directives): def process_revision_directives(context, revision, directives):
if getattr(config.cmd_opts, 'autogenerate', False): if getattr(config.cmd_opts, "autogenerate", False):
script = directives[0] script = directives[0]
if script.upgrade_ops.is_empty(): if script.upgrade_ops.is_empty():
directives[:] = [] directives[:] = []
logger.info('No changes in schema detected.') logger.info("No changes in schema detected.")
engine = engine_from_config(config.get_section(config.config_ini_section), engine = engine_from_config(
prefix='sqlalchemy.', config.get_section(config.config_ini_section),
poolclass=pool.NullPool) prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connection = engine.connect() connection = engine.connect()
kwargs = {} kwargs = {}
if engine.name in ('sqlite', 'mysql'): if engine.name in ("sqlite", "mysql"):
kwargs = { kwargs = {"transaction_per_migration": True, "transactional_ddl": True}
'transaction_per_migration': True, configure_args = current_app.extensions["migrate"].configure_args
'transactional_ddl': True,
}
configure_args = current_app.extensions['migrate'].configure_args
if configure_args: if configure_args:
kwargs.update(configure_args) kwargs.update(configure_args)
context.configure(connection=connection, context.configure(
target_metadata=target_metadata, connection=connection,
# compare_type=True, target_metadata=target_metadata,
process_revision_directives=process_revision_directives, # compare_type=True,
**kwargs) process_revision_directives=process_revision_directives,
**kwargs
)
try: try:
with context.begin_transaction(): with context.begin_transaction():
@ -110,6 +112,7 @@ def run_migrations_online():
finally: finally:
connection.close() connection.close()
if context.is_offline_mode(): if context.is_offline_mode():
run_migrations_offline() run_migrations_offline()
else: else:


@ -25,15 +25,15 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '0b1f1ab473c0' revision = "0b1f1ab473c0"
down_revision = '55e910a74826' down_revision = "55e910a74826"
def upgrade(): def upgrade():
with op.batch_alter_table('query') as batch_op: with op.batch_alter_table("query") as batch_op:
batch_op.add_column(sa.Column('extra_json', sa.Text(), nullable=True)) batch_op.add_column(sa.Column("extra_json", sa.Text(), nullable=True))
def downgrade(): def downgrade():
with op.batch_alter_table('query') as batch_op: with op.batch_alter_table("query") as batch_op:
batch_op.drop_column('extra_json') batch_op.drop_column("extra_json")


@ -23,29 +23,30 @@ Create Date: 2018-08-06 14:38:18.965248
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '0c5070e96b57' revision = "0c5070e96b57"
down_revision = '7fcdcde0761c' down_revision = "7fcdcde0761c"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.create_table('user_attribute', op.create_table(
sa.Column('created_on', sa.DateTime(), nullable=True), "user_attribute",
sa.Column('changed_on', sa.DateTime(), nullable=True), sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False), sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('welcome_dashboard_id', sa.Integer(), nullable=True), sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column('created_by_fk', sa.Integer(), nullable=True), sa.Column("welcome_dashboard_id", sa.Integer(), nullable=True),
sa.Column('changed_by_fk', sa.Integer(), nullable=True), sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['changed_by_fk'], ['ab_user.id'], ), sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['created_by_fk'], ['ab_user.id'], ), sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ), sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(['welcome_dashboard_id'], ['dashboards.id'], ), sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.PrimaryKeyConstraint('id') sa.ForeignKeyConstraint(["welcome_dashboard_id"], ["dashboards.id"]),
sa.PrimaryKeyConstraint("id"),
) )
def downgrade(): def downgrade():
op.drop_table('user_attribute') op.drop_table("user_attribute")


@ -27,42 +27,49 @@ from superset.utils.core import generic_find_constraint_name
import logging import logging
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1226819ee0e3' revision = "1226819ee0e3"
down_revision = '956a063c52b3' down_revision = "956a063c52b3"
naming_convention = { naming_convention = {
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"
} }
def find_constraint_name(upgrade=True): def find_constraint_name(upgrade=True):
cols = {'column_name'} if upgrade else {'datasource_name'} cols = {"column_name"} if upgrade else {"datasource_name"}
return generic_find_constraint_name( return generic_find_constraint_name(
table='columns', columns=cols, referenced='datasources', db=db) table="columns", columns=cols, referenced="datasources", db=db
)
def upgrade(): def upgrade():
try: try:
constraint = find_constraint_name() constraint = find_constraint_name()
with op.batch_alter_table("columns", with op.batch_alter_table(
naming_convention=naming_convention) as batch_op: "columns", naming_convention=naming_convention
) as batch_op:
if constraint: if constraint:
batch_op.drop_constraint(constraint, type_="foreignkey") batch_op.drop_constraint(constraint, type_="foreignkey")
batch_op.create_foreign_key( batch_op.create_foreign_key(
'fk_columns_datasource_name_datasources', "fk_columns_datasource_name_datasources",
'datasources', "datasources",
['datasource_name'], ['datasource_name']) ["datasource_name"],
["datasource_name"],
)
except: except:
logging.warning( logging.warning("Could not find or drop constraint on `columns`")
"Could not find or drop constraint on `columns`")
def downgrade(): def downgrade():
constraint = find_constraint_name(False) or 'fk_columns_datasource_name_datasources' constraint = find_constraint_name(False) or "fk_columns_datasource_name_datasources"
with op.batch_alter_table("columns", with op.batch_alter_table(
naming_convention=naming_convention) as batch_op: "columns", naming_convention=naming_convention
) as batch_op:
batch_op.drop_constraint(constraint, type_="foreignkey") batch_op.drop_constraint(constraint, type_="foreignkey")
batch_op.create_foreign_key( batch_op.create_foreign_key(
'fk_columns_column_name_datasources', "fk_columns_column_name_datasources",
'datasources', "datasources",
['column_name'], ['datasource_name']) ["column_name"],
["datasource_name"],
)


@ -23,16 +23,18 @@ Create Date: 2016-12-06 17:40:40.389652
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1296d28ec131' revision = "1296d28ec131"
down_revision = '6414e83d82b7' down_revision = "6414e83d82b7"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('datasources', sa.Column('params', sa.String(length=1000), nullable=True)) op.add_column(
"datasources", sa.Column("params", sa.String(length=1000), nullable=True)
)
def downgrade(): def downgrade():
op.drop_column('datasources', 'params') op.drop_column("datasources", "params")


@ -23,17 +23,16 @@ Create Date: 2015-12-14 13:37:17.374852
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '12d55656cbca' revision = "12d55656cbca"
down_revision = '55179c7f25c7' down_revision = "55179c7f25c7"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('tables', sa.Column('is_featured', sa.Boolean(), nullable=True)) op.add_column("tables", sa.Column("is_featured", sa.Boolean(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('tables', 'is_featured') op.drop_column("tables", "is_featured")


@ -28,15 +28,16 @@ from sqlalchemy.ext.declarative import declarative_base
from superset import db from superset import db
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '130915240929' revision = "130915240929"
down_revision = 'f231d82b9b26' down_revision = "f231d82b9b26"
Base = declarative_base() Base = declarative_base()
class Table(Base): class Table(Base):
"""Declarative class to do query in upgrade""" """Declarative class to do query in upgrade"""
__tablename__ = 'tables'
__tablename__ = "tables"
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
sql = sa.Column(sa.Text) sql = sa.Column(sa.Text)
is_sqllab_view = sa.Column(sa.Boolean()) is_sqllab_view = sa.Column(sa.Boolean())
@ -45,9 +46,9 @@ class Table(Base):
def upgrade(): def upgrade():
bind = op.get_bind() bind = op.get_bind()
op.add_column( op.add_column(
'tables', "tables",
sa.Column( sa.Column(
'is_sqllab_view', "is_sqllab_view",
sa.Boolean(), sa.Boolean(),
nullable=True, nullable=True,
default=False, default=False,
@ -67,4 +68,4 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_column('tables', 'is_sqllab_view') op.drop_column("tables", "is_sqllab_view")


@ -23,8 +23,8 @@ Create Date: 2019-01-18 14:56:26.307684
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '18dc26817ad2' revision = "18dc26817ad2"
down_revision = ('8b70aa3d0f87', 'a33a03f16c4a') down_revision = ("8b70aa3d0f87", "a33a03f16c4a")
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa


@ -25,106 +25,77 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '18e88e1cc004' revision = "18e88e1cc004"
down_revision = '430039611635' down_revision = "430039611635"
def upgrade(): def upgrade():
try: try:
op.alter_column( op.alter_column(
'clusters', 'changed_on', "clusters", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'clusters', 'created_on', "clusters", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), nullable=True) )
op.drop_constraint(None, 'columns', type_='foreignkey') op.drop_constraint(None, "columns", type_="foreignkey")
op.drop_constraint(None, 'columns', type_='foreignkey') op.drop_constraint(None, "columns", type_="foreignkey")
op.drop_column('columns', 'created_on') op.drop_column("columns", "created_on")
op.drop_column('columns', 'created_by_fk') op.drop_column("columns", "created_by_fk")
op.drop_column('columns', 'changed_on') op.drop_column("columns", "changed_on")
op.drop_column('columns', 'changed_by_fk') op.drop_column("columns", "changed_by_fk")
op.alter_column( op.alter_column(
'css_templates', 'changed_on', "css_templates", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'css_templates', 'created_on', "css_templates", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'dashboards', 'changed_on', "dashboards", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'dashboards', 'created_on', "dashboards", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True) op.create_unique_constraint(None, "dashboards", ["slug"])
op.create_unique_constraint(None, 'dashboards', ['slug'])
op.alter_column( op.alter_column(
'datasources', 'changed_by_fk', "datasources", "changed_by_fk", existing_type=sa.INTEGER(), nullable=True
existing_type=sa.INTEGER(), )
nullable=True)
op.alter_column( op.alter_column(
'datasources', 'changed_on', "datasources", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'datasources', 'created_by_fk', "datasources", "created_by_fk", existing_type=sa.INTEGER(), nullable=True
existing_type=sa.INTEGER(), )
nullable=True)
op.alter_column( op.alter_column(
'datasources', 'created_on', "datasources", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True) op.alter_column("dbs", "changed_on", existing_type=sa.DATETIME(), nullable=True)
op.alter_column("dbs", "created_on", existing_type=sa.DATETIME(), nullable=True)
op.alter_column( op.alter_column(
'dbs', 'changed_on', "slices", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'dbs', 'created_on', "slices", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'slices', 'changed_on', "sql_metrics", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'slices', 'created_on', "sql_metrics", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'sql_metrics', 'changed_on', "table_columns", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'sql_metrics', 'created_on', "table_columns", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'table_columns', 'changed_on', "tables", "changed_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True)
op.alter_column( op.alter_column(
'table_columns', 'created_on', "tables", "created_on", existing_type=sa.DATETIME(), nullable=True
existing_type=sa.DATETIME(), )
nullable=True) op.alter_column("url", "changed_on", existing_type=sa.DATETIME(), nullable=True)
op.alter_column( op.alter_column("url", "created_on", existing_type=sa.DATETIME(), nullable=True)
'tables', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'tables', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'url', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'url', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
except Exception: except Exception:
pass pass


@ -23,20 +23,20 @@ Create Date: 2017-09-15 15:09:40.495345
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '19a814813610' revision = "19a814813610"
down_revision = 'ca69c70ec99b' down_revision = "ca69c70ec99b"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('metrics', sa.Column('warning_text', sa.Text(), nullable=True)) op.add_column("metrics", sa.Column("warning_text", sa.Text(), nullable=True))
op.add_column('sql_metrics', sa.Column('warning_text', sa.Text(), nullable=True)) op.add_column("sql_metrics", sa.Column("warning_text", sa.Text(), nullable=True))
def downgrade(): def downgrade():
with op.batch_alter_table('sql_metrics') as batch_op_sql_metrics: with op.batch_alter_table("sql_metrics") as batch_op_sql_metrics:
batch_op_sql_metrics.drop_column('warning_text') batch_op_sql_metrics.drop_column("warning_text")
with op.batch_alter_table('metrics') as batch_op_metrics: with op.batch_alter_table("metrics") as batch_op_metrics:
batch_op_metrics.drop_column('warning_text') batch_op_metrics.drop_column("warning_text")


@ -29,14 +29,14 @@ import sqlalchemy as sa
from superset.utils.core import MediumText from superset.utils.core import MediumText
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1a1d627ebd8e' revision = "1a1d627ebd8e"
down_revision = '0c5070e96b57' down_revision = "0c5070e96b57"
def upgrade(): def upgrade():
with op.batch_alter_table('dashboards') as batch_op: with op.batch_alter_table("dashboards") as batch_op:
batch_op.alter_column( batch_op.alter_column(
'position_json', "position_json",
existing_type=sa.Text(), existing_type=sa.Text(),
type_=MediumText(), type_=MediumText(),
existing_nullable=True, existing_nullable=True,
@ -44,9 +44,9 @@ def upgrade():
def downgrade(): def downgrade():
with op.batch_alter_table('dashboards') as batch_op: with op.batch_alter_table("dashboards") as batch_op:
batch_op.alter_column( batch_op.alter_column(
'position_json', "position_json",
existing_type=MediumText(), existing_type=MediumText(),
type_=sa.Text(), type_=sa.Text(),
existing_nullable=True, existing_nullable=True,


@ -23,20 +23,21 @@ Create Date: 2015-12-04 09:42:16.973264
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1a48a5411020' revision = "1a48a5411020"
down_revision = '289ce07647b' down_revision = "289ce07647b"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('dashboards', sa.Column('slug', sa.String(length=255), nullable=True)) op.add_column("dashboards", sa.Column("slug", sa.String(length=255), nullable=True))
try: try:
op.create_unique_constraint('idx_unique_slug', 'dashboards', ['slug']) op.create_unique_constraint("idx_unique_slug", "dashboards", ["slug"])
except: except:
pass pass
def downgrade(): def downgrade():
op.drop_constraint(None, 'dashboards', type_='unique') op.drop_constraint(None, "dashboards", type_="unique")
op.drop_column('dashboards', 'slug') op.drop_column("dashboards", "slug")


@ -22,16 +22,16 @@ Create Date: 2016-03-25 14:35:44.642576
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1d2ddd543133' revision = "1d2ddd543133"
down_revision = 'd2424a248d63' down_revision = "d2424a248d63"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('logs', sa.Column('dt', sa.Date(), nullable=True)) op.add_column("logs", sa.Column("dt", sa.Date(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('logs', 'dt') op.drop_column("logs", "dt")


@ -26,19 +26,21 @@ import sqlalchemy as sa
from sqlalchemy.sql import expression from sqlalchemy.sql import expression
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1d9e835a84f9' revision = "1d9e835a84f9"
down_revision = '3dda56f1c4c6' down_revision = "3dda56f1c4c6"
def upgrade(): def upgrade():
op.add_column( op.add_column(
'dbs', "dbs",
sa.Column( sa.Column(
'allow_csv_upload', "allow_csv_upload",
sa.Boolean(), sa.Boolean(),
nullable=False, nullable=False,
server_default=expression.true())) server_default=expression.true(),
),
)
def downgrade(): def downgrade():
op.drop_column('dbs', 'allow_csv_upload') op.drop_column("dbs", "allow_csv_upload")


@ -23,15 +23,16 @@ Create Date: 2015-10-05 22:11:00.537054
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1e2841a4128' revision = "1e2841a4128"
down_revision = '5a7bad26f2a7' down_revision = "5a7bad26f2a7"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('table_columns', sa.Column('expression', sa.Text(), nullable=True)) op.add_column("table_columns", sa.Column("expression", sa.Text(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('table_columns', 'expression') op.drop_column("table_columns", "expression")


@ -17,8 +17,7 @@
import json import json
from alembic import op from alembic import op
from sqlalchemy import ( from sqlalchemy import Column, Integer, or_, String, Text
Column, Integer, or_, String, Text)
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from superset import db from superset import db
@ -32,14 +31,14 @@ Create Date: 2017-12-17 11:06:30.180267
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '21e88bc06c02' revision = "21e88bc06c02"
down_revision = '67a6ac9b727b' down_revision = "67a6ac9b727b"
Base = declarative_base() Base = declarative_base()
class Slice(Base): class Slice(Base):
__tablename__ = 'slices' __tablename__ = "slices"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
viz_type = Column(String(250)) viz_type = Column(String(250))
params = Column(Text) params = Column(Text)
@ -49,24 +48,27 @@ def upgrade():
bind = op.get_bind() bind = op.get_bind()
session = db.Session(bind=bind) session = db.Session(bind=bind)
for slc in session.query(Slice).filter(or_( for slc in session.query(Slice).filter(
Slice.viz_type.like('line'), Slice.viz_type.like('bar'))): or_(Slice.viz_type.like("line"), Slice.viz_type.like("bar"))
):
params = json.loads(slc.params) params = json.loads(slc.params)
layers = params.get('annotation_layers', []) layers = params.get("annotation_layers", [])
if layers: if layers:
new_layers = [] new_layers = []
for layer in layers: for layer in layers:
new_layers.append({ new_layers.append(
'annotationType': 'INTERVAL', {
'style': 'solid', "annotationType": "INTERVAL",
'name': 'Layer {}'.format(layer), "style": "solid",
'show': True, "name": "Layer {}".format(layer),
'overrides': {'since': None, 'until': None}, "show": True,
'value': layer, "overrides": {"since": None, "until": None},
'width': 1, "value": layer,
'sourceType': 'NATIVE', "width": 1,
}) "sourceType": "NATIVE",
params['annotation_layers'] = new_layers }
)
params["annotation_layers"] = new_layers
slc.params = json.dumps(params) slc.params = json.dumps(params)
session.merge(slc) session.merge(slc)
session.commit() session.commit()
@ -77,12 +79,13 @@ def downgrade():
bind = op.get_bind() bind = op.get_bind()
session = db.Session(bind=bind) session = db.Session(bind=bind)
for slc in session.query(Slice).filter(or_( for slc in session.query(Slice).filter(
Slice.viz_type.like('line'), Slice.viz_type.like('bar'))): or_(Slice.viz_type.like("line"), Slice.viz_type.like("bar"))
):
params = json.loads(slc.params) params = json.loads(slc.params)
layers = params.get('annotation_layers', []) layers = params.get("annotation_layers", [])
if layers: if layers:
params['annotation_layers'] = [layer['value'] for layer in layers] params["annotation_layers"] = [layer["value"] for layer in layers]
slc.params = json.dumps(params) slc.params = json.dumps(params)
session.merge(slc) session.merge(slc)
session.commit() session.commit()


@ -23,20 +23,20 @@ Create Date: 2015-12-15 17:02:45.128709
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '2591d77e9831' revision = "2591d77e9831"
down_revision = '12d55656cbca' down_revision = "12d55656cbca"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
with op.batch_alter_table('tables') as batch_op: with op.batch_alter_table("tables") as batch_op:
batch_op.add_column(sa.Column('user_id', sa.Integer())) batch_op.add_column(sa.Column("user_id", sa.Integer()))
batch_op.create_foreign_key('user_id', 'ab_user', ['user_id'], ['id']) batch_op.create_foreign_key("user_id", "ab_user", ["user_id"], ["id"])
def downgrade(): def downgrade():
with op.batch_alter_table('tables') as batch_op: with op.batch_alter_table("tables") as batch_op:
batch_op.drop_constraint('user_id', type_='foreignkey') batch_op.drop_constraint("user_id", type_="foreignkey")
batch_op.drop_column('user_id') batch_op.drop_column("user_id")


@ -23,8 +23,8 @@ Create Date: 2016-06-27 08:43:52.592242
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '27ae655e4247' revision = "27ae655e4247"
down_revision = 'd8bc074f7aad' down_revision = "d8bc074f7aad"
from alembic import op from alembic import op
from superset import db from superset import db
@ -32,41 +32,51 @@ from sqlalchemy.ext.declarative import declarative_base
from flask_appbuilder.models.mixins import AuditMixin from flask_appbuilder.models.mixins import AuditMixin
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from flask_appbuilder import Model from flask_appbuilder import Model
from sqlalchemy import ( from sqlalchemy import Column, Integer, ForeignKey, Table
Column, Integer, ForeignKey, Table)
Base = declarative_base() Base = declarative_base()
class User(Base): class User(Base):
"""Declarative class to do query in upgrade""" """Declarative class to do query in upgrade"""
__tablename__ = 'ab_user'
__tablename__ = "ab_user"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
slice_user = Table('slice_user', Base.metadata,
    Column('id', Integer, primary_key=True),
    Column('user_id', Integer, ForeignKey('ab_user.id')),
    Column('slice_id', Integer, ForeignKey('slices.id'))
)
slice_user = Table(
    "slice_user",
    Base.metadata,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("ab_user.id")),
    Column("slice_id", Integer, ForeignKey("slices.id")),
)
dashboard_user = Table(
    'dashboard_user', Base.metadata,
    Column('id', Integer, primary_key=True),
    Column('user_id', Integer, ForeignKey('ab_user.id')),
    Column('dashboard_id', Integer, ForeignKey('dashboards.id'))
)
dashboard_user = Table(
    "dashboard_user",
    Base.metadata,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("ab_user.id")),
    Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
)
class Slice(Base, AuditMixin): class Slice(Base, AuditMixin):
"""Declarative class to do query in upgrade""" """Declarative class to do query in upgrade"""
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
owners = relationship("User", secondary=slice_user) owners = relationship("User", secondary=slice_user)
class Dashboard(Base, AuditMixin): class Dashboard(Base, AuditMixin):
"""Declarative class to do query in upgrade""" """Declarative class to do query in upgrade"""
__tablename__ = 'dashboards'
__tablename__ = "dashboards"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
owners = relationship("User", secondary=dashboard_user) owners = relationship("User", secondary=dashboard_user)
def upgrade(): def upgrade():
bind = op.get_bind() bind = op.get_bind()
session = db.Session(bind=bind) session = db.Session(bind=bind)


@ -24,21 +24,18 @@ Create Date: 2015-11-21 11:18:00.650587
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import EncryptedType from sqlalchemy_utils import EncryptedType
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '289ce07647b' revision = "289ce07647b"
down_revision = '2929af7925ed' down_revision = "2929af7925ed"
def upgrade(): def upgrade():
op.add_column( op.add_column(
'dbs', "dbs", sa.Column("password", EncryptedType(sa.String(1024)), nullable=True)
sa.Column( )
'password',
EncryptedType(sa.String(1024)),
nullable=True))
def downgrade(): def downgrade():
op.drop_column('dbs', 'password') op.drop_column("dbs", "password")


@ -23,17 +23,18 @@ Create Date: 2015-10-19 20:54:00.565633
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '2929af7925ed' revision = "2929af7925ed"
down_revision = '1e2841a4128' down_revision = "1e2841a4128"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('datasources', sa.Column('offset', sa.Integer(), nullable=True)) op.add_column("datasources", sa.Column("offset", sa.Integer(), nullable=True))
op.add_column('tables', sa.Column('offset', sa.Integer(), nullable=True)) op.add_column("tables", sa.Column("offset", sa.Integer(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('tables', 'offset') op.drop_column("tables", "offset")
op.drop_column('datasources', 'offset') op.drop_column("datasources", "offset")


@ -25,31 +25,31 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '2fcdcb35e487' revision = "2fcdcb35e487"
down_revision = 'a6c18f869a4e' down_revision = "a6c18f869a4e"
def upgrade(): def upgrade():
op.create_table( op.create_table(
'saved_query', "saved_query",
sa.Column('created_on', sa.DateTime(), nullable=True), sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column('changed_on', sa.DateTime(), nullable=True), sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True), sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column('db_id', sa.Integer(), nullable=True), sa.Column("db_id", sa.Integer(), nullable=True),
sa.Column('label', sa.String(256), nullable=True), sa.Column("label", sa.String(256), nullable=True),
sa.Column('schema', sa.String(128), nullable=True), sa.Column("schema", sa.String(128), nullable=True),
sa.Column('sql', sa.Text(), nullable=True), sa.Column("sql", sa.Text(), nullable=True),
sa.Column('description', sa.Text(), nullable=True), sa.Column("description", sa.Text(), nullable=True),
sa.Column('changed_by_fk', sa.Integer(), nullable=True), sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column('created_by_fk', sa.Integer(), nullable=True), sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['changed_by_fk'], ['ab_user.id'], ), sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(['created_by_fk'], ['ab_user.id'], ), sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ), sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.ForeignKeyConstraint(['db_id'], ['dbs.id'], ), sa.ForeignKeyConstraint(["db_id"], ["dbs.id"]),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint("id"),
) )
def downgrade(): def downgrade():
op.drop_table('saved_query') op.drop_table("saved_query")


@ -23,8 +23,8 @@ Create Date: 2018-04-08 07:34:12.149910
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '30bb17c0dc76' revision = "30bb17c0dc76"
down_revision = 'f231d82b9b26' down_revision = "f231d82b9b26"
from datetime import date from datetime import date
@ -33,10 +33,10 @@ import sqlalchemy as sa
def upgrade(): def upgrade():
with op.batch_alter_table('logs') as batch_op: with op.batch_alter_table("logs") as batch_op:
batch_op.drop_column('dt') batch_op.drop_column("dt")
def downgrade(): def downgrade():
with op.batch_alter_table('logs') as batch_op: with op.batch_alter_table("logs") as batch_op:
batch_op.add_column(sa.Column('dt', sa.Date, default=date.today())) batch_op.add_column(sa.Column("dt", sa.Date, default=date.today()))


@ -23,24 +23,25 @@ Create Date: 2015-12-04 11:16:58.226984
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '315b3f4da9b0' revision = "315b3f4da9b0"
down_revision = '1a48a5411020' down_revision = "1a48a5411020"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
    op.create_table('logs',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('action', sa.String(length=512), nullable=True),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.Column('json', sa.Text(), nullable=True),
        sa.Column('dttm', sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table(
        "logs",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("action", sa.String(length=512), nullable=True),
        sa.Column("user_id", sa.Integer(), nullable=True),
        sa.Column("json", sa.Text(), nullable=True),
        sa.Column("dttm", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
def downgrade(): def downgrade():
op.drop_table('logs') op.drop_table("logs")


@ -18,8 +18,7 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from superset import db from superset import db
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
    Column, Integer, String)
from sqlalchemy import Column, Integer, String
"""update slice model """update slice model
@ -30,15 +29,16 @@ Create Date: 2016-09-07 23:50:59.366779
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '33d996bcc382' revision = "33d996bcc382"
down_revision = '41f6a59a61f2' down_revision = "41f6a59a61f2"
Base = declarative_base() Base = declarative_base()
class Slice(Base): class Slice(Base):
"""Declarative class to do query in upgrade""" """Declarative class to do query in upgrade"""
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
datasource_id = Column(Integer) datasource_id = Column(Integer)
druid_datasource_id = Column(Integer) druid_datasource_id = Column(Integer)
@ -48,7 +48,7 @@ class Slice(Base):
def upgrade(): def upgrade():
bind = op.get_bind() bind = op.get_bind()
op.add_column('slices', sa.Column('datasource_id', sa.Integer())) op.add_column("slices", sa.Column("datasource_id", sa.Integer()))
session = db.Session(bind=bind) session = db.Session(bind=bind)
for slc in session.query(Slice).all(): for slc in session.query(Slice).all():
@ -65,11 +65,11 @@ def downgrade():
bind = op.get_bind() bind = op.get_bind()
session = db.Session(bind=bind) session = db.Session(bind=bind)
for slc in session.query(Slice).all(): for slc in session.query(Slice).all():
if slc.datasource_type == 'druid': if slc.datasource_type == "druid":
slc.druid_datasource_id = slc.datasource_id slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == 'table': if slc.datasource_type == "table":
slc.table_id = slc.datasource_id slc.table_id = slc.datasource_id
session.merge(slc) session.merge(slc)
session.commit() session.commit()
session.close() session.close()
op.drop_column('slices', 'datasource_id') op.drop_column("slices", "datasource_id")


@ -32,86 +32,101 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '3b626e2a6783' revision = "3b626e2a6783"
down_revision = 'eca4694defa7' down_revision = "eca4694defa7"
def upgrade(): def upgrade():
# cleanup after: https://github.com/airbnb/superset/pull/1078 # cleanup after: https://github.com/airbnb/superset/pull/1078
try: try:
        slices_ibfk_1 = generic_find_constraint_name(
            table='slices', columns={'druid_datasource_id'},
            referenced='datasources', db=db)
        slices_ibfk_2 = generic_find_constraint_name(
            table='slices', columns={'table_id'},
            referenced='tables', db=db)
        slices_ibfk_1 = generic_find_constraint_name(
            table="slices",
            columns={"druid_datasource_id"},
            referenced="datasources",
            db=db,
        )
        slices_ibfk_2 = generic_find_constraint_name(
            table="slices", columns={"table_id"}, referenced="tables", db=db
        )
with op.batch_alter_table('slices') as batch_op: with op.batch_alter_table("slices") as batch_op:
if slices_ibfk_1: if slices_ibfk_1:
batch_op.drop_constraint(slices_ibfk_1, type_='foreignkey') batch_op.drop_constraint(slices_ibfk_1, type_="foreignkey")
if slices_ibfk_2: if slices_ibfk_2:
batch_op.drop_constraint(slices_ibfk_2, type_='foreignkey') batch_op.drop_constraint(slices_ibfk_2, type_="foreignkey")
batch_op.drop_column('druid_datasource_id') batch_op.drop_column("druid_datasource_id")
batch_op.drop_column('table_id') batch_op.drop_column("table_id")
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
# fixed issue: https://github.com/airbnb/superset/issues/466 # fixed issue: https://github.com/airbnb/superset/issues/466
try: try:
with op.batch_alter_table('columns') as batch_op: with op.batch_alter_table("columns") as batch_op:
            batch_op.create_foreign_key(
                None, 'datasources', ['datasource_name'], ['datasource_name'])
            batch_op.create_foreign_key(
                None, "datasources", ["datasource_name"], ["datasource_name"]
            )
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
try: try:
with op.batch_alter_table('query') as batch_op: with op.batch_alter_table("query") as batch_op:
batch_op.create_unique_constraint('client_id', ['client_id']) batch_op.create_unique_constraint("client_id", ["client_id"])
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
try: try:
with op.batch_alter_table('query') as batch_op: with op.batch_alter_table("query") as batch_op:
batch_op.drop_column('name') batch_op.drop_column("name")
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
def downgrade(): def downgrade():
try: try:
with op.batch_alter_table('tables') as batch_op: with op.batch_alter_table("tables") as batch_op:
batch_op.create_index('table_name', ['table_name'], unique=True) batch_op.create_index("table_name", ["table_name"], unique=True)
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
try: try:
with op.batch_alter_table('slices') as batch_op: with op.batch_alter_table("slices") as batch_op:
            batch_op.add_column(sa.Column(
                'table_id', mysql.INTEGER(display_width=11),
                autoincrement=False, nullable=True))
            batch_op.add_column(sa.Column(
                'druid_datasource_id', sa.Integer(), autoincrement=False,
                nullable=True))
            batch_op.create_foreign_key(
                'slices_ibfk_1', 'datasources', ['druid_datasource_id'],
                ['id'])
            batch_op.create_foreign_key(
                'slices_ibfk_2', 'tables', ['table_id'], ['id'])
            batch_op.add_column(
                sa.Column(
                    "table_id",
                    mysql.INTEGER(display_width=11),
                    autoincrement=False,
                    nullable=True,
                )
            )
            batch_op.add_column(
                sa.Column(
                    "druid_datasource_id",
                    sa.Integer(),
                    autoincrement=False,
                    nullable=True,
                )
            )
            batch_op.create_foreign_key(
                "slices_ibfk_1", "datasources", ["druid_datasource_id"], ["id"]
            )
            batch_op.create_foreign_key("slices_ibfk_2", "tables", ["table_id"], ["id"])
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
try: try:
        fk_columns = generic_find_constraint_name(
            table='columns', columns={'datasource_name'},
            referenced='datasources', db=db)
        with op.batch_alter_table('columns') as batch_op:
            batch_op.drop_constraint(fk_columns, type_='foreignkey')
        fk_columns = generic_find_constraint_name(
            table="columns",
            columns={"datasource_name"},
            referenced="datasources",
            db=db,
        )
        with op.batch_alter_table("columns") as batch_op:
            batch_op.drop_constraint(fk_columns, type_="foreignkey")
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))
    op.add_column(
        'query', sa.Column('name', sa.String(length=256), nullable=True))
    op.add_column("query", sa.Column("name", sa.String(length=256), nullable=True))
try: try:
with op.batch_alter_table('query') as batch_op: with op.batch_alter_table("query") as batch_op:
batch_op.drop_constraint('client_id', type_='unique') batch_op.drop_constraint("client_id", type_="unique")
except Exception as e: except Exception as e:
logging.warning(str(e)) logging.warning(str(e))


@ -23,16 +23,16 @@ Create Date: 2016-08-18 14:06:28.784699
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '3c3ffe173e4f' revision = "3c3ffe173e4f"
down_revision = 'ad82a75afd82' down_revision = "ad82a75afd82"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('tables', sa.Column('sql', sa.Text(), nullable=True)) op.add_column("tables", sa.Column("sql", sa.Text(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('tables', 'sql') op.drop_column("tables", "sql")


@ -35,45 +35,41 @@ from sqlalchemy import Column, Integer, String, Text
from superset import db from superset import db
from superset.utils.core import parse_human_timedelta from superset.utils.core import parse_human_timedelta
revision = '3dda56f1c4c6' revision = "3dda56f1c4c6"
down_revision = 'bddc498dd179' down_revision = "bddc498dd179"
Base = declarative_base() Base = declarative_base()
class Slice(Base): class Slice(Base):
__tablename__ = 'slices' __tablename__ = "slices"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
datasource_type = Column(String(200)) datasource_type = Column(String(200))
params = Column(Text) params = Column(Text)
comparison_type_map = {
    'factor': 'ratio',
    'growth': 'percentage',
    'value': 'absolute',
}
comparison_type_map = {"factor": "ratio", "growth": "percentage", "value": "absolute"}
db_engine_specs_map = { db_engine_specs_map = {
'second': 'PT1S', "second": "PT1S",
'minute': 'PT1M', "minute": "PT1M",
'5 minute': 'PT5M', "5 minute": "PT5M",
'10 minute': 'PT10M', "10 minute": "PT10M",
'half hour': 'PT0.5H', "half hour": "PT0.5H",
'hour': 'PT1H', "hour": "PT1H",
'day': 'P1D', "day": "P1D",
'week': 'P1W', "week": "P1W",
'week_ending_saturday': 'P1W', "week_ending_saturday": "P1W",
'week_start_sunday': 'P1W', "week_start_sunday": "P1W",
'week_start_monday': 'P1W', "week_start_monday": "P1W",
'week_starting_sunday': 'P1W', "week_starting_sunday": "P1W",
'P1W/1970-01-03T00:00:00Z': 'P1W', "P1W/1970-01-03T00:00:00Z": "P1W",
'1969-12-28T00:00:00Z/P1W': 'P1W', "1969-12-28T00:00:00Z/P1W": "P1W",
'month': 'P1M', "month": "P1M",
'quarter': 'P0.25Y', "quarter": "P0.25Y",
'year': 'P1Y', "year": "P1Y",
} }
@ -81,41 +77,36 @@ def isodate_duration_to_string(obj):
if obj.tdelta: if obj.tdelta:
if not obj.months and not obj.years: if not obj.months and not obj.years:
return format_seconds(obj.tdelta.total_seconds()) return format_seconds(obj.tdelta.total_seconds())
raise Exception('Unable to convert: {0}'.format(obj)) raise Exception("Unable to convert: {0}".format(obj))
if obj.months % 12 != 0: if obj.months % 12 != 0:
months = obj.months + 12 * obj.years months = obj.months + 12 * obj.years
return '{0} months'.format(months) return "{0} months".format(months)
return '{0} years'.format(obj.years + obj.months // 12) return "{0} years".format(obj.years + obj.months // 12)
def timedelta_to_string(obj): def timedelta_to_string(obj):
if obj.microseconds: if obj.microseconds:
raise Exception('Unable to convert: {0}'.format(obj)) raise Exception("Unable to convert: {0}".format(obj))
elif obj.seconds: elif obj.seconds:
return format_seconds(obj.total_seconds()) return format_seconds(obj.total_seconds())
elif obj.days % 7 == 0: elif obj.days % 7 == 0:
return '{0} weeks'.format(obj.days // 7) return "{0} weeks".format(obj.days // 7)
else: else:
return '{0} days'.format(obj.days) return "{0} days".format(obj.days)
def format_seconds(value): def format_seconds(value):
    periods = [
        ('minute', 60),
        ('hour', 3600),
        ('day', 86400),
        ('week', 604800),
    ]
    periods = [("minute", 60), ("hour", 3600), ("day", 86400), ("week", 604800)]
for period, multiple in periods: for period, multiple in periods:
if value % multiple == 0: if value % multiple == 0:
value //= multiple value //= multiple
break break
else: else:
period = 'second' period = "second"
return '{0} {1}{2}'.format(value, period, 's' if value > 1 else '') return "{0} {1}{2}".format(value, period, "s" if value > 1 else "")
def compute_time_compare(granularity, periods): def compute_time_compare(granularity, periods):
@ -129,11 +120,11 @@ def compute_time_compare(granularity, periods):
obj = isodate.parse_duration(granularity) * periods obj = isodate.parse_duration(granularity) * periods
except isodate.isoerror.ISO8601Error: except isodate.isoerror.ISO8601Error:
# if parse_human_timedelta can parse it, return it directly # if parse_human_timedelta can parse it, return it directly
delta = '{0} {1}{2}'.format(periods, granularity, 's' if periods > 1 else '') delta = "{0} {1}{2}".format(periods, granularity, "s" if periods > 1 else "")
obj = parse_human_timedelta(delta) obj = parse_human_timedelta(delta)
if obj: if obj:
return delta return delta
raise Exception('Unable to parse: {0}'.format(granularity)) raise Exception("Unable to parse: {0}".format(granularity))
if isinstance(obj, isodate.duration.Duration): if isinstance(obj, isodate.duration.Duration):
return isodate_duration_to_string(obj) return isodate_duration_to_string(obj)
@ -146,21 +137,24 @@ def upgrade():
session = db.Session(bind=bind) session = db.Session(bind=bind)
for chart in session.query(Slice): for chart in session.query(Slice):
params = json.loads(chart.params or '{}') params = json.loads(chart.params or "{}")
if not params.get('num_period_compare'): if not params.get("num_period_compare"):
continue continue
num_period_compare = int(params.get('num_period_compare')) num_period_compare = int(params.get("num_period_compare"))
        granularity = (params.get('granularity') if chart.datasource_type == 'druid'
                       else params.get('time_grain_sqla'))
        granularity = (
            params.get("granularity")
            if chart.datasource_type == "druid"
            else params.get("time_grain_sqla")
        )
time_compare = compute_time_compare(granularity, num_period_compare) time_compare = compute_time_compare(granularity, num_period_compare)
period_ratio_type = params.get('period_ratio_type') or 'growth' period_ratio_type = params.get("period_ratio_type") or "growth"
comparison_type = comparison_type_map[period_ratio_type.lower()] comparison_type = comparison_type_map[period_ratio_type.lower()]
params['time_compare'] = [time_compare] params["time_compare"] = [time_compare]
params['comparison_type'] = comparison_type params["comparison_type"] = comparison_type
chart.params = json.dumps(params, sort_keys=True) chart.params = json.dumps(params, sort_keys=True)
session.commit() session.commit()
@ -172,11 +166,11 @@ def downgrade():
session = db.Session(bind=bind) session = db.Session(bind=bind)
for chart in session.query(Slice): for chart in session.query(Slice):
params = json.loads(chart.params or '{}') params = json.loads(chart.params or "{}")
if 'time_compare' in params or 'comparison_type' in params: if "time_compare" in params or "comparison_type" in params:
params.pop('time_compare', None) params.pop("time_compare", None)
params.pop('comparison_type', None) params.pop("comparison_type", None)
chart.params = json.dumps(params, sort_keys=True) chart.params = json.dumps(params, sort_keys=True)
session.commit() session.commit()
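
For illustration only (not part of the diff): format_seconds above reports the first period that divides the value evenly (checking minutes first) and otherwise falls back to seconds. A self-contained restatement with example outputs:

def format_seconds(value):
    # Same logic as the helper in the migration above.
    periods = [("minute", 60), ("hour", 3600), ("day", 86400), ("week", 604800)]
    for period, multiple in periods:
        if value % multiple == 0:
            value //= multiple
            break
    else:
        period = "second"
    return "{0} {1}{2}".format(value, period, "s" if value > 1 else "")

print(format_seconds(120))  # 2 minutes
print(format_seconds(45))   # 45 seconds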


@ -26,57 +26,63 @@ Create Date: 2018-12-15 12:34:47.228756
from superset import db from superset import db
from superset.utils.core import generic_find_fk_constraint_name from superset.utils.core import generic_find_fk_constraint_name
revision = '3e1b21cd94a4' revision = "3e1b21cd94a4"
down_revision = '6c7537a6004a' down_revision = "6c7537a6004a"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
sqlatable_user = sa.Table(
    'sqlatable_user', sa.MetaData(),
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
    sa.Column('table_id', sa.Integer, sa.ForeignKey('tables.id')),
)
sqlatable_user = sa.Table(
    "sqlatable_user",
    sa.MetaData(),
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
    sa.Column("table_id", sa.Integer, sa.ForeignKey("tables.id")),
)
SqlaTable = sa.Table(
    'tables', sa.MetaData(),
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
)
SqlaTable = sa.Table(
    "tables",
    sa.MetaData(),
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
)
druiddatasource_user = sa.Table(
    'druiddatasource_user', sa.MetaData(),
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
    sa.Column('datasource_id', sa.Integer, sa.ForeignKey('datasources.id')),
)
druiddatasource_user = sa.Table(
    "druiddatasource_user",
    sa.MetaData(),
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
    sa.Column("datasource_id", sa.Integer, sa.ForeignKey("datasources.id")),
)
DruidDatasource = sa.Table(
    'datasources', sa.MetaData(),
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
)
DruidDatasource = sa.Table(
    "datasources",
    sa.MetaData(),
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
)
def upgrade(): def upgrade():
    op.create_table('sqlatable_user',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.Column('table_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['table_id'], ['tables.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('druiddatasource_user',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.Column('datasource_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['datasource_id'], ['datasources.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table(
        "sqlatable_user",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=True),
        sa.Column("table_id", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(["table_id"], ["tables.id"]),
        sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "druiddatasource_user",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=True),
        sa.Column("datasource_id", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(["datasource_id"], ["datasources.id"]),
        sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
bind = op.get_bind() bind = op.get_bind()
insp = sa.engine.reflection.Inspector.from_engine(bind) insp = sa.engine.reflection.Inspector.from_engine(bind)
@ -93,29 +99,31 @@ def upgrade():
for druiddatasource in druiddatasources: for druiddatasource in druiddatasources:
if druiddatasource.user_id is not None: if druiddatasource.user_id is not None:
session.execute( session.execute(
                druiddatasource_user.insert().values(user_id=druiddatasource.user_id, datasource_id=druiddatasource.id)
                druiddatasource_user.insert().values(
                    user_id=druiddatasource.user_id, datasource_id=druiddatasource.id
                )
) )
session.close() session.close()
with op.batch_alter_table('tables') as batch_op: with op.batch_alter_table("tables") as batch_op:
batch_op.drop_constraint('user_id', type_='foreignkey') batch_op.drop_constraint("user_id", type_="foreignkey")
batch_op.drop_column('user_id') batch_op.drop_column("user_id")
with op.batch_alter_table('datasources') as batch_op: with op.batch_alter_table("datasources") as batch_op:
        batch_op.drop_constraint(generic_find_fk_constraint_name(
            'datasources',
            {'id'},
            'ab_user',
            insp,
        ), type_='foreignkey')
        batch_op.drop_column('user_id')
        batch_op.drop_constraint(
            generic_find_fk_constraint_name("datasources", {"id"}, "ab_user", insp),
            type_="foreignkey",
        )
        batch_op.drop_column("user_id")
def downgrade(): def downgrade():
op.drop_table('sqlatable_user') op.drop_table("sqlatable_user")
op.drop_table('druiddatasource_user') op.drop_table("druiddatasource_user")
with op.batch_alter_table('tables') as batch_op: with op.batch_alter_table("tables") as batch_op:
batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True)) batch_op.add_column(sa.Column("user_id", sa.INTEGER(), nullable=True))
batch_op.create_foreign_key('user_id', 'ab_user', ['user_id'], ['id']) batch_op.create_foreign_key("user_id", "ab_user", ["user_id"], ["id"])
with op.batch_alter_table('datasources') as batch_op: with op.batch_alter_table("datasources") as batch_op:
batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True)) batch_op.add_column(sa.Column("user_id", sa.INTEGER(), nullable=True))
        batch_op.create_foreign_key('fk_datasources_user_id_ab_user', 'ab_user', ['user_id'], ['id'])
        batch_op.create_foreign_key(
            "fk_datasources_user_id_ab_user", "ab_user", ["user_id"], ["id"]
        )


@ -25,20 +25,19 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '41f6a59a61f2' revision = "41f6a59a61f2"
down_revision = '3c3ffe173e4f' down_revision = "3c3ffe173e4f"
def upgrade(): def upgrade():
op.add_column('dbs', sa.Column('allow_ctas', sa.Boolean(), nullable=True)) op.add_column("dbs", sa.Column("allow_ctas", sa.Boolean(), nullable=True))
op.add_column("dbs", sa.Column("expose_in_sqllab", sa.Boolean(), nullable=True))
op.add_column( op.add_column(
'dbs', sa.Column('expose_in_sqllab', sa.Boolean(), nullable=True)) "dbs", sa.Column("force_ctas_schema", sa.String(length=250), nullable=True)
op.add_column( )
'dbs',
sa.Column('force_ctas_schema', sa.String(length=250), nullable=True))
def downgrade(): def downgrade():
op.drop_column('dbs', 'force_ctas_schema') op.drop_column("dbs", "force_ctas_schema")
op.drop_column('dbs', 'expose_in_sqllab') op.drop_column("dbs", "expose_in_sqllab")
op.drop_column('dbs', 'allow_ctas') op.drop_column("dbs", "allow_ctas")


@ -25,15 +25,15 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '430039611635' revision = "430039611635"
down_revision = 'd827694c7555' down_revision = "d827694c7555"
def upgrade(): def upgrade():
op.add_column('logs', sa.Column('dashboard_id', sa.Integer(), nullable=True)) op.add_column("logs", sa.Column("dashboard_id", sa.Integer(), nullable=True))
op.add_column('logs', sa.Column('slice_id', sa.Integer(), nullable=True)) op.add_column("logs", sa.Column("slice_id", sa.Integer(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('logs', 'slice_id') op.drop_column("logs", "slice_id")
op.drop_column('logs', 'dashboard_id') op.drop_column("logs", "dashboard_id")


@ -23,16 +23,16 @@ Create Date: 2016-01-18 23:43:16.073483
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '43df8de3a5f4' revision = "43df8de3a5f4"
down_revision = '7dbf98566af7' down_revision = "7dbf98566af7"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('dashboards', sa.Column('json_metadata', sa.Text(), nullable=True)) op.add_column("dashboards", sa.Column("json_metadata", sa.Text(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('dashboards', 'json_metadata') op.drop_column("dashboards", "json_metadata")


@ -23,8 +23,8 @@ Create Date: 2018-06-13 10:20:35.846744
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '4451805bbaa1' revision = "4451805bbaa1"
down_revision = 'bddc498dd179' down_revision = "bddc498dd179"
from alembic import op from alembic import op
@ -38,23 +38,23 @@ Base = declarative_base()
class Slice(Base): class Slice(Base):
__tablename__ = 'slices' __tablename__ = "slices"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
datasource_id = Column(Integer, ForeignKey('tables.id')) datasource_id = Column(Integer, ForeignKey("tables.id"))
datasource_type = Column(String(200)) datasource_type = Column(String(200))
params = Column(Text) params = Column(Text)
class Table(Base): class Table(Base):
__tablename__ = 'tables' __tablename__ = "tables"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
database_id = Column(Integer, ForeignKey('dbs.id')) database_id = Column(Integer, ForeignKey("dbs.id"))
class Database(Base): class Database(Base):
__tablename__ = 'dbs' __tablename__ = "dbs"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
sqlalchemy_uri = Column(String(1024)) sqlalchemy_uri = Column(String(1024))
@ -68,7 +68,7 @@ def replace(source, target):
session.query(Slice, Database) session.query(Slice, Database)
.join(Table, Slice.datasource_id == Table.id) .join(Table, Slice.datasource_id == Table.id)
.join(Database, Table.database_id == Database.id) .join(Database, Table.database_id == Database.id)
.filter(Slice.datasource_type == 'table') .filter(Slice.datasource_type == "table")
.all() .all()
) )
@ -79,11 +79,11 @@ def replace(source, target):
if engine.dialect.identifier_preparer._double_percents: if engine.dialect.identifier_preparer._double_percents:
params = json.loads(slc.params) params = json.loads(slc.params)
if 'adhoc_filters' in params: if "adhoc_filters" in params:
for filt in params['adhoc_filters']: for filt in params["adhoc_filters"]:
if 'sqlExpression' in filt: if "sqlExpression" in filt:
                    filt['sqlExpression'] = (
                        filt['sqlExpression'].replace(source, target)
                    )
                    filt["sqlExpression"] = filt["sqlExpression"].replace(
                        source, target
                    )
slc.params = json.dumps(params, sort_keys=True) slc.params = json.dumps(params, sort_keys=True)
@ -95,8 +95,8 @@ def replace(source, target):
def upgrade(): def upgrade():
replace('%%', '%') replace("%%", "%")
def downgrade(): def downgrade():
replace('%', '%%') replace("%", "%%")
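
For illustration only (not part of the diff): a minimal sketch, with an assumed params shape, of what replace("%%", "%") in the upgrade above does to a chart's adhoc-filter SQL expression.

import json

# Hypothetical stored chart params (assumed shape); only sqlExpression is touched.
params = {"adhoc_filters": [{"sqlExpression": "name LIKE 'a%%'"}]}
source, target = "%%", "%"

for filt in params.get("adhoc_filters", []):
    if "sqlExpression" in filt:
        filt["sqlExpression"] = filt["sqlExpression"].replace(source, target)

print(json.dumps(params))  # {"adhoc_filters": [{"sqlExpression": "name LIKE 'a%'"}]}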


@ -23,22 +23,21 @@ Create Date: 2016-09-12 23:33:14.789632
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '4500485bde7d' revision = "4500485bde7d"
down_revision = '41f6a59a61f2' down_revision = "41f6a59a61f2"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('dbs', sa.Column('allow_run_async', sa.Boolean(), nullable=True)) op.add_column("dbs", sa.Column("allow_run_async", sa.Boolean(), nullable=True))
op.add_column('dbs', sa.Column('allow_run_sync', sa.Boolean(), nullable=True)) op.add_column("dbs", sa.Column("allow_run_sync", sa.Boolean(), nullable=True))
def downgrade(): def downgrade():
try: try:
op.drop_column('dbs', 'allow_run_sync') op.drop_column("dbs", "allow_run_sync")
op.drop_column('dbs', 'allow_run_async') op.drop_column("dbs", "allow_run_async")
except Exception: except Exception:
pass pass


@ -23,8 +23,8 @@ Create Date: 2019-02-16 17:44:44.493427
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '45e7da7cfeba' revision = "45e7da7cfeba"
down_revision = ('e553e78e90c5', 'c82ee8a39623') down_revision = ("e553e78e90c5", "c82ee8a39623")
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
