[format] Using Black (#7769)

John Bodley 2019-06-25 13:34:48 -07:00 committed by GitHub
parent 0c9e6d0985
commit 5c58fd1802
270 changed files with 15592 additions and 14772 deletions

.pre-commit-config.yaml Normal file

@ -0,0 +1,22 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
repos:
- repo: https://github.com/ambv/black
rev: stable
hooks:
- id: black
language_version: python3.6

.pylintrc

@ -81,7 +81,7 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=standarderror-builtin,long-builtin,dict-view-method,intern-builtin,suppressed-message,no-absolute-import,unpacking-in-except,apply-builtin,delslice-method,indexing-exception,old-raise-syntax,print-statement,cmp-builtin,reduce-builtin,useless-suppression,coerce-method,input-builtin,cmp-method,raw_input-builtin,nonzero-method,backtick,basestring-builtin,setslice-method,reload-builtin,oct-method,map-builtin-not-iterating,execfile-builtin,old-octal-literal,zip-builtin-not-iterating,buffer-builtin,getslice-method,metaclass-assignment,xrange-builtin,long-suffix,round-builtin,range-builtin-not-iterating,next-method-called,dict-iter-method,parameter-unpacking,unicode-builtin,unichr-builtin,import-star-module-level,raising-string,filter-builtin-not-iterating,old-ne-operator,using-cmp-argument,coerce-builtin,file-builtin,old-division,hex-method,invalid-unary-operand-type,missing-docstring,too-many-lines,duplicate-code
disable=standarderror-builtin,long-builtin,dict-view-method,intern-builtin,suppressed-message,no-absolute-import,unpacking-in-except,apply-builtin,delslice-method,indexing-exception,old-raise-syntax,print-statement,cmp-builtin,reduce-builtin,useless-suppression,coerce-method,input-builtin,cmp-method,raw_input-builtin,nonzero-method,backtick,basestring-builtin,setslice-method,reload-builtin,oct-method,map-builtin-not-iterating,execfile-builtin,old-octal-literal,zip-builtin-not-iterating,buffer-builtin,getslice-method,metaclass-assignment,xrange-builtin,long-suffix,round-builtin,range-builtin-not-iterating,next-method-called,dict-iter-method,parameter-unpacking,unicode-builtin,unichr-builtin,import-star-module-level,raising-string,filter-builtin-not-iterating,old-ne-operator,using-cmp-argument,coerce-builtin,file-builtin,old-division,hex-method,invalid-unary-operand-type,missing-docstring,too-many-lines,duplicate-code,bad-continuation
[REPORTS]
@ -209,7 +209,7 @@ max-nested-blocks=5
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=90
max-line-length=88
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

CONTRIBUTING.md

@ -296,9 +296,9 @@ python setup.py build_sphinx
#### OS Dependencies
Make sure your machine meets the [OS dependencies](https://superset.incubator.apache.org/installation.html#os-dependencies) before following these steps.
Developers should use a virtualenv.
```
pip install virtualenv
@ -447,6 +447,15 @@ export enum FeatureFlag {
those specified under FEATURE_FLAGS in `superset_config.py`. For example, `DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False }` in `superset/config.py` and `FEATURE_FLAGS = { 'BAR': True, 'BAZ': True }` in `superset_config.py` will result
in combined feature flags of `{ 'FOO': True, 'BAR': True, 'BAZ': True }`.
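As a minimal standalone sketch of that merge (the real code runs at app start in `superset/__init__.py`, shown later in this diff), user-supplied flags simply override the defaults via `dict.update`:

```python
# Sketch only: mirrors the DEFAULT_FEATURE_FLAGS / FEATURE_FLAGS merge described above.
DEFAULT_FEATURE_FLAGS = {"FOO": True, "BAR": False}  # e.g. superset/config.py
FEATURE_FLAGS = {"BAR": True, "BAZ": True}           # e.g. superset_config.py

_feature_flags = dict(DEFAULT_FEATURE_FLAGS)
_feature_flags.update(FEATURE_FLAGS)  # user-defined values win on conflict

assert _feature_flags == {"FOO": True, "BAR": True, "BAZ": True}
```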
## Git Hooks
Superset uses Git pre-commit hooks courtesy of [pre-commit](https://pre-commit.com/). To install, run the following:
```bash
pip3 install -r requirements-dev.txt
pre-commit install
```
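Once installed, the hooks run automatically on `git commit`. To apply them to the whole tree on demand, a sketch using standard `pre-commit` usage (not spelled out in this section):

```bash
# Run every configured hook (currently just black) against all tracked files
pre-commit run --all-files
```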
## Linting
Lint the project with:
@ -461,6 +470,10 @@ npm ci
npm run lint
```
The Python code is auto-formatted using [Black](https://github.com/python/black) which
is configured as a pre-commit hook. There are also numerous [editor integrations](https://black.readthedocs.io/en/stable/editor_integration.html).
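Black can also be invoked directly for a one-off pass; a sketch assuming the `black==19.3b0` pin from `requirements-dev.txt`, with the `superset` and `tests` directories as example targets:

```bash
# Report what would change, then rewrite files in place
black --check superset tests
black superset tests
```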
## Testing
### Python Testing
@ -736,7 +749,7 @@ to work on `async` related features.
To do this, you'll need to:
* Add an additional database entry. We recommend you copy the connection
string from the database labeled `main`, and then enable `SQL Lab` and the
features you want to use. Don't forget to check the `Async` box
* Configure a results backend, here's a local `FileSystemCache` example,
not recommended for production,
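A minimal sketch of such an entry in `superset_config.py`, assuming the `FileSystemCache` backend that Werkzeug provided at the time (the cache directory is illustrative):

```python
from werkzeug.contrib.cache import FileSystemCache

# Local, single-machine results backend for async SQL Lab queries;
# fine for development, not recommended for production.
RESULTS_BACKEND = FileSystemCache("/tmp/sqllab_results", default_timeout=60 * 60 * 24)
```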

requirements-dev.txt

@ -14,17 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
black==19.3b0
coverage==4.5.3
flake8-commas==2.0.0
flake8-import-order==0.18.1
flake8-mypy==17.8.0
flake8-quotes==2.0.1
flake8==3.7.7
flask-cors==3.0.7
ipdb==0.12
mypy==0.670
nose==1.3.7
pip-tools==3.7.0
pre-commit==1.17.0
psycopg2-binary==2.7.5
pycodestyle==2.5.0
pyhive==0.6.1

setup.py

@ -23,113 +23,100 @@ import sys
from setuptools import find_packages, setup
if sys.version_info < (3, 6):
sys.exit('Sorry, Python < 3.6 is not supported')
sys.exit("Sorry, Python < 3.6 is not supported")
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
PACKAGE_DIR = os.path.join(BASE_DIR, 'superset', 'static', 'assets')
PACKAGE_FILE = os.path.join(PACKAGE_DIR, 'package.json')
PACKAGE_DIR = os.path.join(BASE_DIR, "superset", "static", "assets")
PACKAGE_FILE = os.path.join(PACKAGE_DIR, "package.json")
with open(PACKAGE_FILE) as package_file:
version_string = json.load(package_file)['version']
version_string = json.load(package_file)["version"]
with io.open('README.md', encoding='utf-8') as f:
with io.open("README.md", encoding="utf-8") as f:
long_description = f.read()
def get_git_sha():
try:
s = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
s = subprocess.check_output(["git", "rev-parse", "HEAD"])
return s.decode().strip()
except Exception:
return ''
return ""
GIT_SHA = get_git_sha()
version_info = {
'GIT_SHA': GIT_SHA,
'version': version_string,
}
print('-==-' * 15)
print('VERSION: ' + version_string)
print('GIT SHA: ' + GIT_SHA)
print('-==-' * 15)
version_info = {"GIT_SHA": GIT_SHA, "version": version_string}
print("-==-" * 15)
print("VERSION: " + version_string)
print("GIT SHA: " + GIT_SHA)
print("-==-" * 15)
with open(os.path.join(PACKAGE_DIR, 'version_info.json'), 'w') as version_file:
with open(os.path.join(PACKAGE_DIR, "version_info.json"), "w") as version_file:
json.dump(version_info, version_file)
setup(
name='apache-superset',
description=(
'A modern, enterprise-ready business intelligence web application'),
name="apache-superset",
description=("A modern, enterprise-ready business intelligence web application"),
long_description=long_description,
long_description_content_type='text/markdown',
long_description_content_type="text/markdown",
version=version_string,
packages=find_packages(),
include_package_data=True,
zip_safe=False,
scripts=['superset/bin/superset'],
scripts=["superset/bin/superset"],
install_requires=[
'bleach>=3.0.2, <4.0.0',
'celery>=4.2.0, <5.0.0',
'click>=6.0, <7.0.0', # `click`>=7 forces "-" instead of "_"
'colorama',
'contextlib2',
'croniter>=0.3.28',
'cryptography>=2.4.2',
'flask>=1.0.0, <2.0.0',
'flask-appbuilder>=2.1.5, <2.3.0',
'flask-caching',
'flask-compress',
'flask-talisman',
'flask-migrate',
'flask-wtf',
'geopy',
'gunicorn', # deprecated
'humanize',
'idna',
'isodate',
'markdown>=3.0',
'pandas>=0.18.0, <0.24.0', # `pandas`>=0.24.0 changes datetimelike API
'parsedatetime',
'pathlib2',
'polyline',
'pydruid>=0.5.2',
'python-dateutil',
'python-dotenv',
'python-geohash',
'pyyaml>=5.1',
'requests>=2.22.0',
'retry>=0.9.2',
'selenium>=3.141.0',
'simplejson>=3.15.0',
'sqlalchemy>=1.3.5,<2.0',
'sqlalchemy-utils>=0.33.2',
'sqlparse',
'wtforms-json',
"bleach>=3.0.2, <4.0.0",
"celery>=4.2.0, <5.0.0",
"click>=6.0, <7.0.0", # `click`>=7 forces "-" instead of "_"
"colorama",
"contextlib2",
"croniter>=0.3.28",
"cryptography>=2.4.2",
"flask>=1.0.0, <2.0.0",
"flask-appbuilder>=2.1.5, <2.3.0",
"flask-caching",
"flask-compress",
"flask-talisman",
"flask-migrate",
"flask-wtf",
"geopy",
"gunicorn", # deprecated
"humanize",
"idna",
"isodate",
"markdown>=3.0",
"pandas>=0.18.0, <0.24.0", # `pandas`>=0.24.0 changes datetimelike API
"parsedatetime",
"pathlib2",
"polyline",
"pydruid>=0.5.2",
"python-dateutil",
"python-dotenv",
"python-geohash",
"pyyaml>=5.1",
"requests>=2.22.0",
"retry>=0.9.2",
"selenium>=3.141.0",
"simplejson>=3.15.0",
"sqlalchemy>=1.3.5,<2.0",
"sqlalchemy-utils>=0.33.2",
"sqlparse",
"wtforms-json",
],
extras_require={
'bigquery': [
'pybigquery>=0.4.10',
'pandas_gbq>=0.10.0',
],
'cors': ['flask-cors>=2.0.0'],
'gsheets': ['gsheetsdb>=0.1.9'],
'hive': [
'pyhive[hive]>=0.6.1',
'tableschema',
'thrift>=0.11.0, <1.0.0',
],
'mysql': ['mysqlclient==1.4.2.post1'],
'postgres': ['psycopg2-binary==2.7.5'],
'presto': ['pyhive[presto]>=0.4.0'],
"bigquery": ["pybigquery>=0.4.10", "pandas_gbq>=0.10.0"],
"cors": ["flask-cors>=2.0.0"],
"gsheets": ["gsheetsdb>=0.1.9"],
"hive": ["pyhive[hive]>=0.6.1", "tableschema", "thrift>=0.11.0, <1.0.0"],
"mysql": ["mysqlclient==1.4.2.post1"],
"postgres": ["psycopg2-binary==2.7.5"],
"presto": ["pyhive[presto]>=0.4.0"],
},
author='Apache Software Foundation',
author_email='dev@superset.incubator.apache.org',
url='https://superset.apache.org/',
author="Apache Software Foundation",
author_email="dev@superset.incubator.apache.org",
url="https://superset.apache.org/",
download_url=(
'https://dist.apache.org/repos/dist/release/superset/' + version_string
"https://dist.apache.org/repos/dist/release/superset/" + version_string
),
classifiers=[
'Programming Language :: Python :: 3.6',
],
classifiers=["Programming Language :: Python :: 3.6"],
)

superset/__init__.py

@ -40,7 +40,7 @@ from superset.utils.core import pessimistic_connection_handling, setup_cache
wtforms_json.init()
APP_DIR = os.path.dirname(__file__)
CONFIG_MODULE = os.environ.get('SUPERSET_CONFIG', 'superset.config')
CONFIG_MODULE = os.environ.get("SUPERSET_CONFIG", "superset.config")
if not os.path.exists(config.DATA_DIR):
os.makedirs(config.DATA_DIR)
@ -52,18 +52,18 @@ conf = app.config
#################################################################
# Handling manifest file logic at app start
#################################################################
MANIFEST_FILE = APP_DIR + '/static/assets/dist/manifest.json'
MANIFEST_FILE = APP_DIR + "/static/assets/dist/manifest.json"
manifest = {}
def parse_manifest_json():
global manifest
try:
with open(MANIFEST_FILE, 'r') as f:
with open(MANIFEST_FILE, "r") as f:
# the manifest inclues non-entry files
# we only need entries in templates
full_manifest = json.load(f)
manifest = full_manifest.get('entrypoints', {})
manifest = full_manifest.get("entrypoints", {})
except Exception:
pass
@ -72,14 +72,14 @@ def get_js_manifest_files(filename):
if app.debug:
parse_manifest_json()
entry_files = manifest.get(filename, {})
return entry_files.get('js', [])
return entry_files.get("js", [])
def get_css_manifest_files(filename):
if app.debug:
parse_manifest_json()
entry_files = manifest.get(filename, {})
return entry_files.get('css', [])
return entry_files.get("css", [])
def get_unloaded_chunks(files, loaded_chunks):
@ -104,16 +104,16 @@ def get_manifest():
#################################################################
for bp in conf.get('BLUEPRINTS'):
for bp in conf.get("BLUEPRINTS"):
try:
print("Registering blueprint: '{}'".format(bp.name))
app.register_blueprint(bp)
except Exception as e:
print('blueprint registration failed')
print("blueprint registration failed")
logging.exception(e)
if conf.get('SILENCE_FAB'):
logging.getLogger('flask_appbuilder').setLevel(logging.ERROR)
if conf.get("SILENCE_FAB"):
logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
if app.debug:
app.logger.setLevel(logging.DEBUG) # pylint: disable=no-member
@ -121,44 +121,46 @@ else:
# In production mode, add log handler to sys.stderr.
app.logger.addHandler(logging.StreamHandler()) # pylint: disable=no-member
app.logger.setLevel(logging.INFO) # pylint: disable=no-member
logging.getLogger('pyhive.presto').setLevel(logging.INFO)
logging.getLogger("pyhive.presto").setLevel(logging.INFO)
db = SQLA(app)
if conf.get('WTF_CSRF_ENABLED'):
if conf.get("WTF_CSRF_ENABLED"):
csrf = CSRFProtect(app)
csrf_exempt_list = conf.get('WTF_CSRF_EXEMPT_LIST', [])
csrf_exempt_list = conf.get("WTF_CSRF_EXEMPT_LIST", [])
for ex in csrf_exempt_list:
csrf.exempt(ex)
pessimistic_connection_handling(db.engine)
cache = setup_cache(app, conf.get('CACHE_CONFIG'))
tables_cache = setup_cache(app, conf.get('TABLE_NAMES_CACHE_CONFIG'))
cache = setup_cache(app, conf.get("CACHE_CONFIG"))
tables_cache = setup_cache(app, conf.get("TABLE_NAMES_CACHE_CONFIG"))
migrate = Migrate(app, db, directory=APP_DIR + '/migrations')
migrate = Migrate(app, db, directory=APP_DIR + "/migrations")
# Logging configuration
logging.basicConfig(format=app.config.get('LOG_FORMAT'))
logging.getLogger().setLevel(app.config.get('LOG_LEVEL'))
logging.basicConfig(format=app.config.get("LOG_FORMAT"))
logging.getLogger().setLevel(app.config.get("LOG_LEVEL"))
if app.config.get('ENABLE_TIME_ROTATE'):
logging.getLogger().setLevel(app.config.get('TIME_ROTATE_LOG_LEVEL'))
if app.config.get("ENABLE_TIME_ROTATE"):
logging.getLogger().setLevel(app.config.get("TIME_ROTATE_LOG_LEVEL"))
handler = TimedRotatingFileHandler(
app.config.get('FILENAME'),
when=app.config.get('ROLLOVER'),
interval=app.config.get('INTERVAL'),
backupCount=app.config.get('BACKUP_COUNT'))
app.config.get("FILENAME"),
when=app.config.get("ROLLOVER"),
interval=app.config.get("INTERVAL"),
backupCount=app.config.get("BACKUP_COUNT"),
)
logging.getLogger().addHandler(handler)
if app.config.get('ENABLE_CORS'):
if app.config.get("ENABLE_CORS"):
from flask_cors import CORS
CORS(app, **app.config.get('CORS_OPTIONS'))
if app.config.get('ENABLE_PROXY_FIX'):
CORS(app, **app.config.get("CORS_OPTIONS"))
if app.config.get("ENABLE_PROXY_FIX"):
app.wsgi_app = ProxyFix(app.wsgi_app)
if app.config.get('ENABLE_CHUNK_ENCODING'):
if app.config.get("ENABLE_CHUNK_ENCODING"):
class ChunkedEncodingFix(object):
def __init__(self, app):
@ -167,40 +169,41 @@ if app.config.get('ENABLE_CHUNK_ENCODING'):
def __call__(self, environ, start_response):
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
if environ.get('HTTP_TRANSFER_ENCODING', '').lower() == u'chunked':
environ['wsgi.input_terminated'] = True
if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == u"chunked":
environ["wsgi.input_terminated"] = True
return self.app(environ, start_response)
app.wsgi_app = ChunkedEncodingFix(app.wsgi_app)
if app.config.get('UPLOAD_FOLDER'):
if app.config.get("UPLOAD_FOLDER"):
try:
os.makedirs(app.config.get('UPLOAD_FOLDER'))
os.makedirs(app.config.get("UPLOAD_FOLDER"))
except OSError:
pass
for middleware in app.config.get('ADDITIONAL_MIDDLEWARE'):
for middleware in app.config.get("ADDITIONAL_MIDDLEWARE"):
app.wsgi_app = middleware(app.wsgi_app)
class MyIndexView(IndexView):
@expose('/')
@expose("/")
def index(self):
return redirect('/superset/welcome')
return redirect("/superset/welcome")
custom_sm = app.config.get('CUSTOM_SECURITY_MANAGER') or SupersetSecurityManager
custom_sm = app.config.get("CUSTOM_SECURITY_MANAGER") or SupersetSecurityManager
if not issubclass(custom_sm, SupersetSecurityManager):
raise Exception(
"""Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager,
not FAB's security manager.
See [4565] in UPDATING.md""")
See [4565] in UPDATING.md"""
)
with app.app_context():
appbuilder = AppBuilder(
app,
db.session,
base_template='superset/base.html',
base_template="superset/base.html",
indexview=MyIndexView,
security_manager_class=custom_sm,
update_perms=False, # Run `superset init` to update FAB's perms
@ -208,15 +211,15 @@ with app.app_context():
security_manager = appbuilder.sm
results_backend = app.config.get('RESULTS_BACKEND')
results_backend = app.config.get("RESULTS_BACKEND")
# Merge user defined feature flags with default feature flags
_feature_flags = app.config.get('DEFAULT_FEATURE_FLAGS') or {}
_feature_flags.update(app.config.get('FEATURE_FLAGS') or {})
_feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
_feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
def get_feature_flags():
GET_FEATURE_FLAGS_FUNC = app.config.get('GET_FEATURE_FLAGS_FUNC')
GET_FEATURE_FLAGS_FUNC = app.config.get("GET_FEATURE_FLAGS_FUNC")
if GET_FEATURE_FLAGS_FUNC:
return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags))
return _feature_flags
@ -228,22 +231,22 @@ def is_feature_enabled(feature):
# Flask-Compress
if conf.get('ENABLE_FLASK_COMPRESS'):
if conf.get("ENABLE_FLASK_COMPRESS"):
Compress(app)
if app.config['TALISMAN_ENABLED']:
talisman_config = app.config.get('TALISMAN_CONFIG')
if app.config["TALISMAN_ENABLED"]:
talisman_config = app.config.get("TALISMAN_CONFIG")
Talisman(app, **talisman_config)
# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = app.config.get('FLASK_APP_MUTATOR')
flask_app_mutator = app.config.get("FLASK_APP_MUTATOR")
if flask_app_mutator:
flask_app_mutator(app)
from superset import views # noqa
# Registering sources
module_datasource_map = app.config.get('DEFAULT_MODULE_DS_MAP')
module_datasource_map.update(app.config.get('ADDITIONAL_MODULE_DS_MAP'))
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map)

superset/cli.py

@ -26,11 +26,8 @@ from colorama import Fore, Style
from pathlib2 import Path
import yaml
from superset import (
app, appbuilder, data, db, security_manager,
)
from superset.utils import (
core as utils, dashboard_import_export, dict_import_export)
from superset import app, appbuilder, data, db, security_manager
from superset.utils import core as utils, dashboard_import_export, dict_import_export
config = app.config
celery_app = utils.get_celery_app(config)
@ -54,114 +51,128 @@ def init():
@app.cli.command()
@click.option('--verbose', '-v', is_flag=True, help='Show extra information')
@click.option("--verbose", "-v", is_flag=True, help="Show extra information")
def version(verbose):
"""Prints the current version number"""
print(Fore.BLUE + '-=' * 15)
print(Fore.YELLOW + 'Superset ' + Fore.CYAN + '{version}'.format(
version=config.get('VERSION_STRING')))
print(Fore.BLUE + '-=' * 15)
print(Fore.BLUE + "-=" * 15)
print(
Fore.YELLOW
+ "Superset "
+ Fore.CYAN
+ "{version}".format(version=config.get("VERSION_STRING"))
)
print(Fore.BLUE + "-=" * 15)
if verbose:
print('[DB] : ' + '{}'.format(db.engine))
print("[DB] : " + "{}".format(db.engine))
print(Style.RESET_ALL)
def load_examples_run(load_test_data):
print('Loading examples into {}'.format(db))
print("Loading examples into {}".format(db))
data.load_css_templates()
print('Loading energy related dataset')
print("Loading energy related dataset")
data.load_energy()
print("Loading [World Bank's Health Nutrition and Population Stats]")
data.load_world_bank_health_n_pop()
print('Loading [Birth names]')
print("Loading [Birth names]")
data.load_birth_names()
print('Loading [Unicode test data]')
print("Loading [Unicode test data]")
data.load_unicode_test_data()
if not load_test_data:
print('Loading [Random time series data]')
print("Loading [Random time series data]")
data.load_random_time_series_data()
print('Loading [Random long/lat data]')
print("Loading [Random long/lat data]")
data.load_long_lat_data()
print('Loading [Country Map data]')
print("Loading [Country Map data]")
data.load_country_map_data()
print('Loading [Multiformat time series]')
print("Loading [Multiformat time series]")
data.load_multiformat_time_series()
print('Loading [Paris GeoJson]')
print("Loading [Paris GeoJson]")
data.load_paris_iris_geojson()
print('Loading [San Francisco population polygons]')
print("Loading [San Francisco population polygons]")
data.load_sf_population_polygons()
print('Loading [Flights data]')
print("Loading [Flights data]")
data.load_flights()
print('Loading [BART lines]')
print("Loading [BART lines]")
data.load_bart_lines()
print('Loading [Multi Line]')
print("Loading [Multi Line]")
data.load_multi_line()
print('Loading [Misc Charts] dashboard')
print("Loading [Misc Charts] dashboard")
data.load_misc_dashboard()
print('Loading DECK.gl demo')
print("Loading DECK.gl demo")
data.load_deck_dash()
print('Loading [Tabbed dashboard]')
print("Loading [Tabbed dashboard]")
data.load_tabbed_dashboard()
@app.cli.command()
@click.option('--load-test-data', '-t', is_flag=True, help='Load additional test data')
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
def load_examples(load_test_data):
"""Loads a set of Slices and Dashboards and a supporting dataset """
load_examples_run(load_test_data)
@app.cli.command()
@click.option('--datasource', '-d', help='Specify which datasource name to load, if '
'omitted, all datasources will be refreshed')
@click.option('--merge', '-m', is_flag=True, default=False,
help="Specify using 'merge' property during operation. "
'Default value is False.')
@click.option(
"--datasource",
"-d",
help="Specify which datasource name to load, if "
"omitted, all datasources will be refreshed",
)
@click.option(
"--merge",
"-m",
is_flag=True,
default=False,
help="Specify using 'merge' property during operation. " "Default value is False.",
)
def refresh_druid(datasource, merge):
"""Refresh druid datasources"""
session = db.session()
from superset.connectors.druid.models import DruidCluster
for cluster in session.query(DruidCluster).all():
try:
cluster.refresh_datasources(datasource_name=datasource,
merge_flag=merge)
cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
except Exception as e:
print(
"Error while processing cluster '{}'\n{}".format(
cluster, str(e)))
print("Error while processing cluster '{}'\n{}".format(cluster, str(e)))
logging.exception(e)
cluster.metadata_last_refreshed = datetime.now()
print(
'Refreshed metadata from cluster '
'[' + cluster.cluster_name + ']')
print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
session.commit()
@app.cli.command()
@click.option(
'--path', '-p',
help='Path to a single JSON file or path containing multiple JSON files'
'files to import (*.json)')
"--path",
"-p",
help="Path to a single JSON file or path containing multiple JSON files"
"files to import (*.json)",
)
@click.option(
'--recursive', '-r', is_flag=True, default=False,
help='recursively search the path for json files')
"--recursive",
"-r",
is_flag=True,
default=False,
help="recursively search the path for json files",
)
def import_dashboards(path, recursive):
"""Import dashboards from JSON"""
p = Path(path)
@ -169,114 +180,135 @@ def import_dashboards(path, recursive):
if p.is_file():
files.append(p)
elif p.exists() and not recursive:
files.extend(p.glob('*.json'))
files.extend(p.glob("*.json"))
elif p.exists() and recursive:
files.extend(p.rglob('*.json'))
files.extend(p.rglob("*.json"))
for f in files:
logging.info('Importing dashboard from file %s', f)
logging.info("Importing dashboard from file %s", f)
try:
with f.open() as data_stream:
dashboard_import_export.import_dashboards(
db.session, data_stream)
dashboard_import_export.import_dashboards(db.session, data_stream)
except Exception as e:
logging.error('Error when importing dashboard from file %s', f)
logging.error("Error when importing dashboard from file %s", f)
logging.error(e)
@app.cli.command()
@click.option(
'--dashboard-file', '-f', default=None,
help='Specify the the file to export to')
"--dashboard-file", "-f", default=None, help="Specify the the file to export to"
)
@click.option(
'--print_stdout', '-p', is_flag=True, default=False,
help='Print JSON to stdout')
"--print_stdout", "-p", is_flag=True, default=False, help="Print JSON to stdout"
)
def export_dashboards(print_stdout, dashboard_file):
"""Export dashboards to JSON"""
data = dashboard_import_export.export_dashboards(db.session)
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
logging.info('Exporting dashboards to %s', dashboard_file)
with open(dashboard_file, 'w') as data_stream:
logging.info("Exporting dashboards to %s", dashboard_file)
with open(dashboard_file, "w") as data_stream:
data_stream.write(data)
@app.cli.command()
@click.option(
'--path', '-p',
help='Path to a single YAML file or path containing multiple YAML '
'files to import (*.yaml or *.yml)')
"--path",
"-p",
help="Path to a single YAML file or path containing multiple YAML "
"files to import (*.yaml or *.yml)",
)
@click.option(
'--sync', '-s', 'sync', default='',
help='comma seperated list of element types to synchronize '
'e.g. "metrics,columns" deletes metrics and columns in the DB '
'that are not specified in the YAML file')
"--sync",
"-s",
"sync",
default="",
help="comma seperated list of element types to synchronize "
'e.g. "metrics,columns" deletes metrics and columns in the DB '
"that are not specified in the YAML file",
)
@click.option(
'--recursive', '-r', is_flag=True, default=False,
help='recursively search the path for yaml files')
"--recursive",
"-r",
is_flag=True,
default=False,
help="recursively search the path for yaml files",
)
def import_datasources(path, sync, recursive):
"""Import datasources from YAML"""
sync_array = sync.split(',')
sync_array = sync.split(",")
p = Path(path)
files = []
if p.is_file():
files.append(p)
elif p.exists() and not recursive:
files.extend(p.glob('*.yaml'))
files.extend(p.glob('*.yml'))
files.extend(p.glob("*.yaml"))
files.extend(p.glob("*.yml"))
elif p.exists() and recursive:
files.extend(p.rglob('*.yaml'))
files.extend(p.rglob('*.yml'))
files.extend(p.rglob("*.yaml"))
files.extend(p.rglob("*.yml"))
for f in files:
logging.info('Importing datasources from file %s', f)
logging.info("Importing datasources from file %s", f)
try:
with f.open() as data_stream:
dict_import_export.import_from_dict(
db.session,
yaml.safe_load(data_stream),
sync=sync_array)
db.session, yaml.safe_load(data_stream), sync=sync_array
)
except Exception as e:
logging.error('Error when importing datasources from file %s', f)
logging.error("Error when importing datasources from file %s", f)
logging.error(e)
@app.cli.command()
@click.option(
'--datasource-file', '-f', default=None,
help='Specify the the file to export to')
"--datasource-file", "-f", default=None, help="Specify the the file to export to"
)
@click.option(
'--print_stdout', '-p', is_flag=True, default=False,
help='Print YAML to stdout')
"--print_stdout", "-p", is_flag=True, default=False, help="Print YAML to stdout"
)
@click.option(
'--back-references', '-b', is_flag=True, default=False,
help='Include parent back references')
"--back-references",
"-b",
is_flag=True,
default=False,
help="Include parent back references",
)
@click.option(
'--include-defaults', '-d', is_flag=True, default=False,
help='Include fields containing defaults')
def export_datasources(print_stdout, datasource_file,
back_references, include_defaults):
"--include-defaults",
"-d",
is_flag=True,
default=False,
help="Include fields containing defaults",
)
def export_datasources(
print_stdout, datasource_file, back_references, include_defaults
):
"""Export datasources to YAML"""
data = dict_import_export.export_to_dict(
session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults)
include_defaults=include_defaults,
)
if print_stdout or not datasource_file:
yaml.safe_dump(data, stdout, default_flow_style=False)
if datasource_file:
logging.info('Exporting datasources to %s', datasource_file)
with open(datasource_file, 'w') as data_stream:
logging.info("Exporting datasources to %s", datasource_file)
with open(datasource_file, "w") as data_stream:
yaml.safe_dump(data, data_stream, default_flow_style=False)
@app.cli.command()
@click.option(
'--back-references', '-b', is_flag=True, default=False,
help='Include parent back references')
"--back-references",
"-b",
is_flag=True,
default=False,
help="Include parent back references",
)
def export_datasource_schema(back_references):
"""Export datasource YAML schema to stdout"""
data = dict_import_export.export_schema_to_dict(
back_references=back_references)
data = dict_import_export.export_schema_to_dict(back_references=back_references)
yaml.safe_dump(data, stdout, default_flow_style=False)
@ -284,47 +316,49 @@ def export_datasource_schema(back_references):
def update_datasources_cache():
"""Refresh sqllab datasources cache"""
from superset.models.core import Database
for database in db.session.query(Database).all():
if database.allow_multi_schema_metadata_fetch:
print('Fetching {} datasources ...'.format(database.name))
print("Fetching {} datasources ...".format(database.name))
try:
database.get_all_table_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60)
force=True, cache=True, cache_timeout=24 * 60 * 60
)
database.get_all_view_names_in_database(
force=True, cache=True, cache_timeout=24 * 60 * 60)
force=True, cache=True, cache_timeout=24 * 60 * 60
)
except Exception as e:
print('{}'.format(str(e)))
print("{}".format(str(e)))
@app.cli.command()
@click.option(
'--workers', '-w',
type=int,
help='Number of celery server workers to fire up')
"--workers", "-w", type=int, help="Number of celery server workers to fire up"
)
def worker(workers):
"""Starts a Superset worker for async SQL query execution."""
logging.info(
"The 'superset worker' command is deprecated. Please use the 'celery "
"worker' command instead.")
"worker' command instead."
)
if workers:
celery_app.conf.update(CELERYD_CONCURRENCY=workers)
elif config.get('SUPERSET_CELERY_WORKERS'):
elif config.get("SUPERSET_CELERY_WORKERS"):
celery_app.conf.update(
CELERYD_CONCURRENCY=config.get('SUPERSET_CELERY_WORKERS'))
CELERYD_CONCURRENCY=config.get("SUPERSET_CELERY_WORKERS")
)
worker = celery_app.Worker(optimization='fair')
worker = celery_app.Worker(optimization="fair")
worker.start()
@app.cli.command()
@click.option(
'-p', '--port',
default='5555',
help='Port on which to start the Flower process')
"-p", "--port", default="5555", help="Port on which to start the Flower process"
)
@click.option(
'-a', '--address',
default='localhost',
help='Address on which to run the service')
"-a", "--address", default="localhost", help="Address on which to run the service"
)
def flower(port, address):
"""Runs a Celery Flower web server
@ -332,18 +366,19 @@ def flower(port, address):
broker"""
BROKER_URL = celery_app.conf.BROKER_URL
cmd = (
'celery flower '
f'--broker={BROKER_URL} '
f'--port={port} '
f'--address={address} '
"celery flower "
f"--broker={BROKER_URL} "
f"--port={port} "
f"--address={address} "
)
logging.info(
"The 'superset flower' command is deprecated. Please use the 'celery "
"flower' command instead.")
print(Fore.GREEN + 'Starting a Celery Flower instance')
print(Fore.BLUE + '-=' * 40)
"flower' command instead."
)
print(Fore.GREEN + "Starting a Celery Flower instance")
print(Fore.BLUE + "-=" * 40)
print(Fore.YELLOW + cmd)
print(Fore.BLUE + '-=' * 40)
print(Fore.BLUE + "-=" * 40)
Popen(cmd, shell=True).wait()
@ -354,7 +389,7 @@ def load_test_users():
Syncs permissions for those users/roles
"""
print(Fore.GREEN + 'Loading a set of users for unit tests')
print(Fore.GREEN + "Loading a set of users for unit tests")
load_test_users_run()
@ -364,51 +399,73 @@ def load_test_users_run():
Syncs permissions for those users/roles
"""
if config.get('TESTING'):
if config.get("TESTING"):
security_manager.sync_role_definitions()
gamma_sqllab_role = security_manager.add_role('gamma_sqllab')
for perm in security_manager.find_role('Gamma').permissions:
gamma_sqllab_role = security_manager.add_role("gamma_sqllab")
for perm in security_manager.find_role("Gamma").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
db_perm = utils.get_main_database(security_manager.get_session).perm
security_manager.add_permission_view_menu('database_access', db_perm)
security_manager.add_permission_view_menu("database_access", db_perm)
db_pvm = security_manager.find_permission_view_menu(
view_menu_name=db_perm, permission_name='database_access')
view_menu_name=db_perm, permission_name="database_access"
)
gamma_sqllab_role.permissions.append(db_pvm)
for perm in security_manager.find_role('sql_lab').permissions:
for perm in security_manager.find_role("sql_lab").permissions:
security_manager.add_permission_role(gamma_sqllab_role, perm)
admin = security_manager.find_user('admin')
admin = security_manager.find_user("admin")
if not admin:
security_manager.add_user(
'admin', 'admin', ' user', 'admin@fab.org',
security_manager.find_role('Admin'),
password='general')
"admin",
"admin",
" user",
"admin@fab.org",
security_manager.find_role("Admin"),
password="general",
)
gamma = security_manager.find_user('gamma')
gamma = security_manager.find_user("gamma")
if not gamma:
security_manager.add_user(
'gamma', 'gamma', 'user', 'gamma@fab.org',
security_manager.find_role('Gamma'),
password='general')
"gamma",
"gamma",
"user",
"gamma@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma2 = security_manager.find_user('gamma2')
gamma2 = security_manager.find_user("gamma2")
if not gamma2:
security_manager.add_user(
'gamma2', 'gamma2', 'user', 'gamma2@fab.org',
security_manager.find_role('Gamma'),
password='general')
"gamma2",
"gamma2",
"user",
"gamma2@fab.org",
security_manager.find_role("Gamma"),
password="general",
)
gamma_sqllab_user = security_manager.find_user('gamma_sqllab')
gamma_sqllab_user = security_manager.find_user("gamma_sqllab")
if not gamma_sqllab_user:
security_manager.add_user(
'gamma_sqllab', 'gamma_sqllab', 'user', 'gamma_sqllab@fab.org',
gamma_sqllab_role, password='general')
"gamma_sqllab",
"gamma_sqllab",
"user",
"gamma_sqllab@fab.org",
gamma_sqllab_role,
password="general",
)
alpha = security_manager.find_user('alpha')
alpha = security_manager.find_user("alpha")
if not alpha:
security_manager.add_user(
'alpha', 'alpha', 'user', 'alpha@fab.org',
security_manager.find_role('Alpha'),
password='general')
"alpha",
"alpha",
"user",
"alpha@fab.org",
security_manager.find_role("Alpha"),
password="general",
)
security_manager.get_session.commit()

superset/common/query_context.py

@ -32,7 +32,7 @@ from superset.utils.core import DTTM_ALIAS
from .query_object import QueryObject
config = app.config
stats_logger = config.get('STATS_LOGGER')
stats_logger = config.get("STATS_LOGGER")
class QueryContext:
@ -41,21 +41,21 @@ class QueryContext:
to retrieve the data payload for a given viz.
"""
cache_type = 'df'
cache_type = "df"
enforce_numerical_metrics = True
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
def __init__(
self,
datasource: Dict,
queries: List[Dict],
force: bool = False,
custom_cache_timeout: int = None,
self,
datasource: Dict,
queries: List[Dict],
force: bool = False,
custom_cache_timeout: int = None,
):
self.datasource = ConnectorRegistry.get_datasource(datasource.get('type'),
int(datasource.get('id')), # noqa: E501, T400
db.session)
self.datasource = ConnectorRegistry.get_datasource(
datasource.get("type"), int(datasource.get("id")), db.session # noqa: T400
)
self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
self.force = force
@ -72,7 +72,7 @@ class QueryContext:
# support multiple queries from different data source.
timestamp_format = None
if self.datasource.type == 'table':
if self.datasource.type == "table":
dttm_col = self.datasource.get_col(query_object.granularity)
if dttm_col:
timestamp_format = dttm_col.python_date_format
@ -88,12 +88,13 @@ class QueryContext:
# parsing logic
if df is not None and not df.empty:
if DTTM_ALIAS in df.columns:
if timestamp_format in ('epoch_s', 'epoch_ms'):
if timestamp_format in ("epoch_s", "epoch_ms"):
# Column has already been formatted as a timestamp.
df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(pd.Timestamp)
else:
df[DTTM_ALIAS] = pd.to_datetime(
df[DTTM_ALIAS], utc=False, format=timestamp_format)
df[DTTM_ALIAS], utc=False, format=timestamp_format
)
if self.datasource.offset:
df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset)
df[DTTM_ALIAS] += query_object.time_shift
@ -103,10 +104,10 @@ class QueryContext:
df.replace([np.inf, -np.inf], np.nan)
return {
'query': result.query,
'status': result.status,
'error_message': result.error_message,
'df': df,
"query": result.query,
"status": result.status,
"error_message": result.error_message,
"df": df,
}
def df_metrics_to_num(self, df, query_object):
@ -114,23 +115,23 @@ class QueryContext:
metrics = [metric for metric in query_object.metrics]
for col, dtype in df.dtypes.items():
if dtype.type == np.object_ and col in metrics:
df[col] = pd.to_numeric(df[col], errors='coerce')
df[col] = pd.to_numeric(df[col], errors="coerce")
def get_data(self, df):
return df.to_dict(orient='records')
return df.to_dict(orient="records")
def get_single_payload(self, query_obj):
"""Returns a payload of metadata and data"""
payload = self.get_df_payload(query_obj)
df = payload.get('df')
status = payload.get('status')
df = payload.get("df")
status = payload.get("status")
if status != utils.QueryStatus.FAILED:
if df is not None and df.empty:
payload['error'] = 'No data'
payload["error"] = "No data"
else:
payload['data'] = self.get_data(df)
if 'df' in payload:
del payload['df']
payload["data"] = self.get_data(df)
if "df" in payload:
del payload["df"]
return payload
def get_payload(self):
@ -144,94 +145,94 @@ class QueryContext:
if self.datasource.cache_timeout is not None:
return self.datasource.cache_timeout
if (
hasattr(self.datasource, 'database') and
self.datasource.database.cache_timeout) is not None:
hasattr(self.datasource, "database")
and self.datasource.database.cache_timeout
) is not None:
return self.datasource.database.cache_timeout
return config.get('CACHE_DEFAULT_TIMEOUT')
return config.get("CACHE_DEFAULT_TIMEOUT")
def get_df_payload(self, query_obj, **kwargs):
"""Handles caching around the df paylod retrieval"""
cache_key = query_obj.cache_key(
datasource=self.datasource.uid, **kwargs) if query_obj else None
logging.info('Cache key: {}'.format(cache_key))
cache_key = (
query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
if query_obj
else None
)
logging.info("Cache key: {}".format(cache_key))
is_loaded = False
stacktrace = None
df = None
cached_dttm = datetime.utcnow().isoformat().split('.')[0]
cached_dttm = datetime.utcnow().isoformat().split(".")[0]
cache_value = None
status = None
query = ''
query = ""
error_message = None
if cache_key and cache and not self.force:
cache_value = cache.get(cache_key)
if cache_value:
stats_logger.incr('loaded_from_cache')
stats_logger.incr("loaded_from_cache")
try:
cache_value = pkl.loads(cache_value)
df = cache_value['df']
query = cache_value['query']
df = cache_value["df"]
query = cache_value["query"]
status = utils.QueryStatus.SUCCESS
is_loaded = True
except Exception as e:
logging.exception(e)
logging.error('Error reading cache: ' +
utils.error_msg_from_exception(e))
logging.info('Serving from cache')
logging.error(
"Error reading cache: " + utils.error_msg_from_exception(e)
)
logging.info("Serving from cache")
if query_obj and not is_loaded:
try:
query_result = self.get_query_result(query_obj)
status = query_result['status']
query = query_result['query']
error_message = query_result['error_message']
df = query_result['df']
status = query_result["status"]
query = query_result["query"]
error_message = query_result["error_message"]
df = query_result["df"]
if status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
stats_logger.incr("loaded_from_source")
is_loaded = True
except Exception as e:
logging.exception(e)
if not error_message:
error_message = '{}'.format(e)
error_message = "{}".format(e)
status = utils.QueryStatus.FAILED
stacktrace = traceback.format_exc()
if (
is_loaded and
cache_key and
cache and
status != utils.QueryStatus.FAILED):
if is_loaded and cache_key and cache and status != utils.QueryStatus.FAILED:
try:
cache_value = dict(
dttm=cached_dttm,
df=df if df is not None else None,
query=query,
dttm=cached_dttm, df=df if df is not None else None, query=query
)
cache_binary = pkl.dumps(
cache_value, protocol=pkl.HIGHEST_PROTOCOL)
cache_binary = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
logging.info('Caching {} chars at key {}'.format(
len(cache_binary), cache_key))
logging.info(
"Caching {} chars at key {}".format(
len(cache_binary), cache_key
)
)
stats_logger.incr('set_cache_key')
stats_logger.incr("set_cache_key")
cache.set(
cache_key,
cache_value=cache_binary,
timeout=self.cache_timeout)
cache_key, cache_value=cache_binary, timeout=self.cache_timeout
)
except Exception as e:
# cache.set call can fail if the backend is down or if
# the key is too large or whatever other reasons
logging.warning('Could not cache key {}'.format(cache_key))
logging.warning("Could not cache key {}".format(cache_key))
logging.exception(e)
cache.delete(cache_key)
return {
'cache_key': cache_key,
'cached_dttm': cache_value['dttm'] if cache_value is not None else None,
'cache_timeout': self.cache_timeout,
'df': df,
'error': error_message,
'is_cached': cache_key is not None,
'query': query,
'status': status,
'stacktrace': stacktrace,
'rowcount': len(df.index) if df is not None else 0,
"cache_key": cache_key,
"cached_dttm": cache_value["dttm"] if cache_value is not None else None,
"cache_timeout": self.cache_timeout,
"df": df,
"error": error_message,
"is_cached": cache_key is not None,
"query": query,
"status": status,
"stacktrace": stacktrace,
"rowcount": len(df.index) if df is not None else 0,
}

superset/common/query_object.py

@ -27,6 +27,7 @@ from superset.utils import core as utils
# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
# https://github.com/python/mypy/issues/5288
class QueryObject:
"""
The query object's schema matches the interfaces of DB connectors like sqla
@ -34,25 +35,25 @@ class QueryObject:
"""
def __init__(
self,
granularity: str,
metrics: List[Union[Dict, str]],
groupby: List[str] = None,
filters: List[str] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
timeseries_limit: int = 0,
row_limit: int = app.config.get('ROW_LIMIT'),
timeseries_limit_metric: Optional[Dict] = None,
order_desc: bool = True,
extras: Optional[Dict] = None,
prequeries: Optional[List[Dict]] = None,
is_prequery: bool = False,
columns: List[str] = None,
orderby: List[List] = None,
relative_start: str = app.config.get('DEFAULT_RELATIVE_START_TIME', 'today'),
relative_end: str = app.config.get('DEFAULT_RELATIVE_END_TIME', 'today'),
self,
granularity: str,
metrics: List[Union[Dict, str]],
groupby: List[str] = None,
filters: List[str] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
timeseries_limit: int = 0,
row_limit: int = app.config.get("ROW_LIMIT"),
timeseries_limit_metric: Optional[Dict] = None,
order_desc: bool = True,
extras: Optional[Dict] = None,
prequeries: Optional[List[Dict]] = None,
is_prequery: bool = False,
columns: List[str] = None,
orderby: List[List] = None,
relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"),
):
self.granularity = granularity
self.from_dttm, self.to_dttm = utils.get_since_until(
@ -69,7 +70,7 @@ class QueryObject:
# Temporal solution for backward compatability issue
# due the new format of non-ad-hoc metric.
self.metrics = [
metric if 'expressionType' in metric else metric['label'] # noqa: T484
metric if "expressionType" in metric else metric["label"] # noqa: T484
for metric in metrics
]
self.row_limit = row_limit
@ -85,22 +86,22 @@ class QueryObject:
def to_dict(self):
query_object_dict = {
'granularity': self.granularity,
'from_dttm': self.from_dttm,
'to_dttm': self.to_dttm,
'is_timeseries': self.is_timeseries,
'groupby': self.groupby,
'metrics': self.metrics,
'row_limit': self.row_limit,
'filter': self.filter,
'timeseries_limit': self.timeseries_limit,
'timeseries_limit_metric': self.timeseries_limit_metric,
'order_desc': self.order_desc,
'prequeries': self.prequeries,
'is_prequery': self.is_prequery,
'extras': self.extras,
'columns': self.columns,
'orderby': self.orderby,
"granularity": self.granularity,
"from_dttm": self.from_dttm,
"to_dttm": self.to_dttm,
"is_timeseries": self.is_timeseries,
"groupby": self.groupby,
"metrics": self.metrics,
"row_limit": self.row_limit,
"filter": self.filter,
"timeseries_limit": self.timeseries_limit,
"timeseries_limit_metric": self.timeseries_limit_metric,
"order_desc": self.order_desc,
"prequeries": self.prequeries,
"is_prequery": self.is_prequery,
"extras": self.extras,
"columns": self.columns,
"orderby": self.orderby,
}
return query_object_dict
@ -115,17 +116,14 @@ class QueryObject:
cache_dict = self.to_dict()
cache_dict.update(extra)
for k in ['from_dttm', 'to_dttm']:
for k in ["from_dttm", "to_dttm"]:
del cache_dict[k]
if self.time_range:
cache_dict['time_range'] = self.time_range
cache_dict["time_range"] = self.time_range
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode('utf-8')).hexdigest()
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
def json_dumps(self, obj, sort_keys=False):
return json.dumps(
obj,
default=utils.json_int_dttm_ser,
ignore_nan=True,
sort_keys=sort_keys,
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)

superset/config.py

@ -37,18 +37,18 @@ from superset.stats_logger import DummyStatsLogger
STATS_LOGGER = DummyStatsLogger()
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
if 'SUPERSET_HOME' in os.environ:
DATA_DIR = os.environ['SUPERSET_HOME']
if "SUPERSET_HOME" in os.environ:
DATA_DIR = os.environ["SUPERSET_HOME"]
else:
DATA_DIR = os.path.join(os.path.expanduser('~'), '.superset')
DATA_DIR = os.path.join(os.path.expanduser("~"), ".superset")
# ---------------------------------------------------------
# Superset specific config
# ---------------------------------------------------------
PACKAGE_DIR = os.path.join(BASE_DIR, 'static', 'assets')
PACKAGE_FILE = os.path.join(PACKAGE_DIR, 'package.json')
PACKAGE_DIR = os.path.join(BASE_DIR, "static", "assets")
PACKAGE_FILE = os.path.join(PACKAGE_DIR, "package.json")
with open(PACKAGE_FILE) as package_file:
VERSION_STRING = json.load(package_file)['version']
VERSION_STRING = json.load(package_file)["version"]
ROW_LIMIT = 50000
VIZ_ROW_LIMIT = 10000
@ -57,7 +57,7 @@ FILTER_SELECT_ROW_LIMIT = 10000
SUPERSET_WORKERS = 2 # deprecated
SUPERSET_CELERY_WORKERS = 32 # deprecated
SUPERSET_WEBSERVER_ADDRESS = '0.0.0.0'
SUPERSET_WEBSERVER_ADDRESS = "0.0.0.0"
SUPERSET_WEBSERVER_PORT = 8088
# This is an important setting, and should be lower than your
@ -73,10 +73,10 @@ SQLALCHEMY_TRACK_MODIFICATIONS = False
# ---------------------------------------------------------
# Your App secret key
SECRET_KEY = '\2\1thisismyscretkey\1\2\e\y\y\h' # noqa
SECRET_KEY = "\2\1thisismyscretkey\1\2\e\y\y\h" # noqa
# The SQLAlchemy connection string.
SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(DATA_DIR, 'superset.db')
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "superset.db")
# SQLALCHEMY_DATABASE_URI = 'mysql://myapp@localhost/myapp'
# SQLALCHEMY_DATABASE_URI = 'postgresql://root:password@localhost/myapp'
@ -96,10 +96,10 @@ QUERY_SEARCH_LIMIT = 1000
WTF_CSRF_ENABLED = True
# Add endpoints that need to be exempt from CSRF protection
WTF_CSRF_EXEMPT_LIST = ['superset.views.core.log']
WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log"]
# Whether to run the web server in debug mode or not
DEBUG = os.environ.get('FLASK_ENV') == 'development'
DEBUG = os.environ.get("FLASK_ENV") == "development"
FLASK_USE_RELOAD = True
# Whether to show the stacktrace on 500 error
@ -112,10 +112,10 @@ ENABLE_PROXY_FIX = False
# GLOBALS FOR APP Builder
# ------------------------------
# Uncomment to setup Your App name
APP_NAME = 'Superset'
APP_NAME = "Superset"
# Uncomment to setup an App icon
APP_ICON = '/static/assets/images/superset-logo@2x.png'
APP_ICON = "/static/assets/images/superset-logo@2x.png"
APP_ICON_WIDTH = 126
# Uncomment to specify where clicking the logo would take the user
@ -131,7 +131,7 @@ LOGO_TARGET_PATH = None
# other tz can be overridden by providing a local_config
DRUID_IS_ACTIVE = True
DRUID_TZ = tz.tzutc()
DRUID_ANALYSIS_TYPES = ['cardinality']
DRUID_ANALYSIS_TYPES = ["cardinality"]
# ----------------------------------------------------
# AUTHENTICATION CONFIG
@ -175,21 +175,21 @@ PUBLIC_ROLE_LIKE_GAMMA = False
# Babel config for translations
# ---------------------------------------------------
# Setup default language
BABEL_DEFAULT_LOCALE = 'en'
BABEL_DEFAULT_LOCALE = "en"
# Your application default translation path
BABEL_DEFAULT_FOLDER = 'superset/translations'
BABEL_DEFAULT_FOLDER = "superset/translations"
# The allowed translation for you app
LANGUAGES = {
'en': {'flag': 'us', 'name': 'English'},
'it': {'flag': 'it', 'name': 'Italian'},
'fr': {'flag': 'fr', 'name': 'French'},
'zh': {'flag': 'cn', 'name': 'Chinese'},
'ja': {'flag': 'jp', 'name': 'Japanese'},
'de': {'flag': 'de', 'name': 'German'},
'pt': {'flag': 'pt', 'name': 'Portuguese'},
'pt_BR': {'flag': 'br', 'name': 'Brazilian Portuguese'},
'ru': {'flag': 'ru', 'name': 'Russian'},
'ko': {'flag': 'kr', 'name': 'Korean'},
"en": {"flag": "us", "name": "English"},
"it": {"flag": "it", "name": "Italian"},
"fr": {"flag": "fr", "name": "French"},
"zh": {"flag": "cn", "name": "Chinese"},
"ja": {"flag": "jp", "name": "Japanese"},
"de": {"flag": "de", "name": "German"},
"pt": {"flag": "pt", "name": "Portuguese"},
"pt_BR": {"flag": "br", "name": "Brazilian Portuguese"},
"ru": {"flag": "ru", "name": "Russian"},
"ko": {"flag": "kr", "name": "Korean"},
}
# ---------------------------------------------------
@ -202,7 +202,7 @@ LANGUAGES = {
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
DEFAULT_FEATURE_FLAGS = {
# Experimental feature introducing a client (browser) cache
'CLIENT_CACHE': False,
"CLIENT_CACHE": False
}
# A function that receives a dict of all feature flags
@ -225,19 +225,19 @@ GET_FEATURE_FLAGS_FUNC = None
# Image and file configuration
# ---------------------------------------------------
# The file upload folder, when using models with files
UPLOAD_FOLDER = BASE_DIR + '/app/static/uploads/'
UPLOAD_FOLDER = BASE_DIR + "/app/static/uploads/"
# The image upload folder, when using models with images
IMG_UPLOAD_FOLDER = BASE_DIR + '/app/static/uploads/'
IMG_UPLOAD_FOLDER = BASE_DIR + "/app/static/uploads/"
# The image upload url, when using models with images
IMG_UPLOAD_URL = '/static/uploads/'
IMG_UPLOAD_URL = "/static/uploads/"
# Setup image size default is (300, 200, True)
# IMG_SIZE = (300, 200, True)
CACHE_DEFAULT_TIMEOUT = 60 * 60 * 24
CACHE_CONFIG = {'CACHE_TYPE': 'null'}
TABLE_NAMES_CACHE_CONFIG = {'CACHE_TYPE': 'null'}
CACHE_CONFIG = {"CACHE_TYPE": "null"}
TABLE_NAMES_CACHE_CONFIG = {"CACHE_TYPE": "null"}
# CORS Options
ENABLE_CORS = False
@ -252,13 +252,12 @@ SUPERSET_WEBSERVER_DOMAINS = None
# Allowed format types for upload on Database view
# TODO: Add processing of other spreadsheet formats (xls, xlsx etc)
ALLOWED_EXTENSIONS = set(['csv'])
ALLOWED_EXTENSIONS = set(["csv"])
# CSV Options: key/value pairs that will be passed as argument to DataFrame.to_csv method
# CSV Options: key/value pairs that will be passed as argument to DataFrame.to_csv
# method.
# note: index option should not be overridden
CSV_EXPORT = {
'encoding': 'utf-8',
}
CSV_EXPORT = {"encoding": "utf-8"}
# ---------------------------------------------------
# Time grain configurations
@ -301,10 +300,12 @@ DRUID_DATA_SOURCE_BLACKLIST = []
# --------------------------------------------------
# Modules, datasources and middleware to be registered
# --------------------------------------------------
DEFAULT_MODULE_DS_MAP = OrderedDict([
('superset.connectors.sqla.models', ['SqlaTable']),
('superset.connectors.druid.models', ['DruidDatasource']),
])
DEFAULT_MODULE_DS_MAP = OrderedDict(
[
("superset.connectors.sqla.models", ["SqlaTable"]),
("superset.connectors.druid.models", ["DruidDatasource"]),
]
)
ADDITIONAL_MODULE_DS_MAP = {}
ADDITIONAL_MIDDLEWARE = []
@ -315,8 +316,8 @@ ADDITIONAL_MIDDLEWARE = []
# Console Log Settings
LOG_FORMAT = '%(asctime)s:%(levelname)s:%(name)s:%(message)s'
LOG_LEVEL = 'DEBUG'
LOG_FORMAT = "%(asctime)s:%(levelname)s:%(name)s:%(message)s"
LOG_LEVEL = "DEBUG"
# ---------------------------------------------------
# Enable Time Rotate Log Handler
@ -324,9 +325,9 @@ LOG_LEVEL = 'DEBUG'
# LOG_LEVEL = DEBUG, INFO, WARNING, ERROR, CRITICAL
ENABLE_TIME_ROTATE = False
TIME_ROTATE_LOG_LEVEL = 'DEBUG'
FILENAME = os.path.join(DATA_DIR, 'superset.log')
ROLLOVER = 'midnight'
TIME_ROTATE_LOG_LEVEL = "DEBUG"
FILENAME = os.path.join(DATA_DIR, "superset.log")
ROLLOVER = "midnight"
INTERVAL = 1
BACKUP_COUNT = 30
@ -344,7 +345,7 @@ BACKUP_COUNT = 30
# pass
# Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = os.environ.get('MAPBOX_API_KEY', '')
MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")
# Maximum number of rows returned from a database
# in async mode, no more than SQL_MAX_ROW will be returned and stored
@ -378,31 +379,26 @@ WARNING_MSG = None
class CeleryConfig(object):
BROKER_URL = 'sqla+sqlite:///celerydb.sqlite'
CELERY_IMPORTS = (
'superset.sql_lab',
'superset.tasks',
)
CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite'
CELERYD_LOG_LEVEL = 'DEBUG'
BROKER_URL = "sqla+sqlite:///celerydb.sqlite"
CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks")
CELERY_RESULT_BACKEND = "db+sqlite:///celery_results.sqlite"
CELERYD_LOG_LEVEL = "DEBUG"
CELERYD_PREFETCH_MULTIPLIER = 1
CELERY_ACKS_LATE = True
CELERY_ANNOTATIONS = {
'sql_lab.get_sql_results': {
'rate_limit': '100/s',
},
'email_reports.send': {
'rate_limit': '1/s',
'time_limit': 120,
'soft_time_limit': 150,
'ignore_result': True,
"sql_lab.get_sql_results": {"rate_limit": "100/s"},
"email_reports.send": {
"rate_limit": "1/s",
"time_limit": 120,
"soft_time_limit": 150,
"ignore_result": True,
},
}
CELERYBEAT_SCHEDULE = {
'email_reports.schedule_hourly': {
'task': 'email_reports.schedule_hourly',
'schedule': crontab(minute=1, hour='*'),
},
"email_reports.schedule_hourly": {
"task": "email_reports.schedule_hourly",
"schedule": crontab(minute=1, hour="*"),
}
}
@ -444,7 +440,7 @@ CSV_TO_HIVE_UPLOAD_S3_BUCKET = None
# The directory within the bucket specified above that will
# contain all the external tables
CSV_TO_HIVE_UPLOAD_DIRECTORY = 'EXTERNAL_HIVE_TABLES/'
CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
# The namespace within hive where the tables created from
# uploading CSVs will be stored.
@ -458,9 +454,9 @@ JINJA_CONTEXT_ADDONS = {}
# Roles that are controlled by the API / Superset and should not be changes
# by humans.
ROBOT_PERMISSION_ROLES = ['Public', 'Gamma', 'Alpha', 'Admin', 'sql_lab']
ROBOT_PERMISSION_ROLES = ["Public", "Gamma", "Alpha", "Admin", "sql_lab"]
CONFIG_PATH_ENV_VAR = 'SUPERSET_CONFIG_PATH'
CONFIG_PATH_ENV_VAR = "SUPERSET_CONFIG_PATH"
# If a callable is specified, it will be called at app startup while passing
# a reference to the Flask app. This can be used to alter the Flask app
@ -474,16 +470,16 @@ ENABLE_ACCESS_REQUEST = False
# smtp server configuration
EMAIL_NOTIFICATIONS = False # all the emails are sent using dryrun
SMTP_HOST = 'localhost'
SMTP_HOST = "localhost"
SMTP_STARTTLS = True
SMTP_SSL = False
SMTP_USER = 'superset'
SMTP_USER = "superset"
SMTP_PORT = 25
SMTP_PASSWORD = 'superset'
SMTP_MAIL_FROM = 'superset@superset.com'
SMTP_PASSWORD = "superset"
SMTP_MAIL_FROM = "superset@superset.com"
if not CACHE_DEFAULT_TIMEOUT:
CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT')
CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get("CACHE_DEFAULT_TIMEOUT")
# Whether to bump the logging level to ERROR on the flask_appbuilder package
# Set to False if/when debugging FAB related issues like
@ -492,14 +488,14 @@ SILENCE_FAB = True
# The link to a page containing common errors and their resolutions
# It will be appended at the bottom of sql_lab errors.
TROUBLESHOOTING_LINK = ''
TROUBLESHOOTING_LINK = ""
# CSRF token timeout, set to None for a token that never expires
WTF_CSRF_TIME_LIMIT = 60 * 60 * 24 * 7
# This link should lead to a page with instructions on how to gain access to a
# Datasource. It will be placed at the bottom of permissions errors.
PERMISSION_INSTRUCTIONS_LINK = ''
PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app
@ -565,7 +561,7 @@ EMAIL_REPORTS_CRON_RESOLUTION = 15
# Email report configuration
# From address in emails
EMAIL_REPORT_FROM_ADDRESS = 'reports@superset.org'
EMAIL_REPORT_FROM_ADDRESS = "reports@superset.org"
# Send bcc of all reports to this address. Set to None to disable.
# This is useful for maintaining an audit trail of all email deliveries.
@ -575,8 +571,8 @@ EMAIL_REPORT_BCC_ADDRESS = None
# This user should have permissions to browse all the dashboards and
# slices.
# TODO: In the future, login as the owner of the item to generate reports
EMAIL_REPORTS_USER = 'admin'
EMAIL_REPORTS_SUBJECT_PREFIX = '[Report] '
EMAIL_REPORTS_USER = "admin"
EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
# The webdriver to use for generating reports. Use one of the following
# firefox
@ -585,19 +581,16 @@ EMAIL_REPORTS_SUBJECT_PREFIX = '[Report] '
# chrome:
# Requires: headless chrome
# Limitations: unable to generate screenshots of elements
EMAIL_REPORTS_WEBDRIVER = 'firefox'
EMAIL_REPORTS_WEBDRIVER = "firefox"
# Window size - this will impact the rendering of the data
WEBDRIVER_WINDOW = {
'dashboard': (1600, 2000),
'slice': (3000, 1200),
}
WEBDRIVER_WINDOW = {"dashboard": (1600, 2000), "slice": (3000, 1200)}
# Any config options to be passed as-is to the webdriver
WEBDRIVER_CONFIGURATION = {}
# The base URL to query for accessing the user interface
WEBDRIVER_BASEURL = 'http://0.0.0.0:8080/'
WEBDRIVER_BASEURL = "http://0.0.0.0:8080/"
# Send user to a link where they can report bugs
BUG_REPORT_URL = None
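
EMAIL_REPORTS_WEBDRIVER, WEBDRIVER_WINDOW and WEBDRIVER_BASEURL above control how the report renderer drives a browser against the UI. A config override might look like the sketch below; every value shown is an illustrative assumption.

# superset_config.py -- illustrative report-rendering overrides (assumed values)
EMAIL_REPORTS_WEBDRIVER = "chrome"  # headless chrome instead of the firefox default
WEBDRIVER_WINDOW = {"dashboard": (1920, 3000), "slice": (1600, 1200)}
WEBDRIVER_BASEURL = "http://superset.internal:8088/"
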
@ -611,33 +604,34 @@ DOCUMENTATION_URL = None
# filter a moving window. By only setting the end time to now,
# start time will be set to midnight, while end will be relative to
# the query issue time.
DEFAULT_RELATIVE_START_TIME = 'today'
DEFAULT_RELATIVE_END_TIME = 'today'
DEFAULT_RELATIVE_START_TIME = "today"
DEFAULT_RELATIVE_END_TIME = "today"
# Configure which SQL validator to use for each engine
SQL_VALIDATORS_BY_ENGINE = {
'presto': 'PrestoDBSQLValidator',
}
SQL_VALIDATORS_BY_ENGINE = {"presto": "PrestoDBSQLValidator"}
# Do you want Talisman enabled?
TALISMAN_ENABLED = False
# If you want Talisman, how do you want it configured??
TALISMAN_CONFIG = {
'content_security_policy': None,
'force_https': True,
'force_https_permanent': False,
"content_security_policy": None,
"force_https": True,
"force_https_permanent": False,
}
try:
if CONFIG_PATH_ENV_VAR in os.environ:
# Explicitly import config module that is not in pythonpath; useful
# for case where app is being executed via pex.
print('Loaded your LOCAL configuration at [{}]'.format(
os.environ[CONFIG_PATH_ENV_VAR]))
print(
"Loaded your LOCAL configuration at [{}]".format(
os.environ[CONFIG_PATH_ENV_VAR]
)
)
module = sys.modules[__name__]
override_conf = imp.load_source(
'superset_config',
os.environ[CONFIG_PATH_ENV_VAR])
"superset_config", os.environ[CONFIG_PATH_ENV_VAR]
)
for key in dir(override_conf):
if key.isupper():
setattr(module, key, getattr(override_conf, key))
@ -645,7 +639,9 @@ try:
else:
from superset_config import * # noqa
import superset_config
print('Loaded your LOCAL configuration at [{}]'.format(
superset_config.__file__))
print(
"Loaded your LOCAL configuration at [{}]".format(superset_config.__file__)
)
except ImportError:
pass
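
The try/except block above is the override hook: when SUPERSET_CONFIG_PATH is set, the named file is loaded with imp.load_source and only its upper-case attributes are copied onto this module; otherwise a superset_config module is imported from the Python path. A minimal override file might look like the sketch below (the path and values are assumptions for illustration); it would be picked up by exporting SUPERSET_CONFIG_PATH=/app/superset_config.py before starting the app.

# /app/superset_config.py -- hypothetical override module; only UPPER_CASE names are applied
SQL_MAX_ROW = 100000
MAPBOX_API_KEY = "pk.placeholder-token"
SQL_VALIDATORS_BY_ENGINE = {"presto": "PrestoDBSQLValidator"}
DEFAULT_RELATIVE_START_TIME = "today"
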

View File

@ -17,9 +17,7 @@
# pylint: disable=C,R,W
import json
from sqlalchemy import (
and_, Boolean, Column, Integer, String, Text,
)
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import foreign, relationship
@ -67,7 +65,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@declared_attr
def slices(self):
return relationship(
'Slice',
"Slice",
primaryjoin=lambda: and_(
foreign(Slice.datasource_id) == self.id,
foreign(Slice.datasource_type) == self.type,
@ -82,11 +80,11 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@property
def uid(self):
"""Unique id across datasource types"""
return f'{self.id}__{self.type}'
return f"{self.id}__{self.type}"
@property
def column_names(self):
return sorted([c.column_name for c in self.columns], key=lambda x: x or '')
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
def columns_types(self):
@ -94,7 +92,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@property
def main_dttm_col(self):
return 'timestamp'
return "timestamp"
@property
def datasource_name(self):
@ -120,22 +118,18 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
@property
def url(self):
return '/{}/edit/{}'.format(self.baselink, self.id)
return "/{}/edit/{}".format(self.baselink, self.id)
@property
def explore_url(self):
if self.default_endpoint:
return self.default_endpoint
else:
return '/superset/explore/{obj.type}/{obj.id}/'.format(obj=self)
return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
@property
def column_formats(self):
return {
m.metric_name: m.d3format
for m in self.metrics
if m.d3format
}
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
def add_missing_metrics(self, metrics):
exisiting_metrics = {m.metric_name for m in self.metrics}
@ -148,14 +142,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
def short_data(self):
"""Data representation of the datasource sent to the frontend"""
return {
'edit_url': self.url,
'id': self.id,
'uid': self.uid,
'schema': self.schema,
'name': self.name,
'type': self.type,
'connection': self.connection,
'creator': str(self.created_by),
"edit_url": self.url,
"id": self.id,
"uid": self.uid,
"schema": self.schema,
"name": self.name,
"type": self.type,
"connection": self.connection,
"creator": str(self.created_by),
}
@property
@ -168,68 +162,65 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
order_by_choices = []
# self.column_names returns sorted column_names
for s in self.column_names:
s = str(s or '')
order_by_choices.append((json.dumps([s, True]), s + ' [asc]'))
order_by_choices.append((json.dumps([s, False]), s + ' [desc]'))
s = str(s or "")
order_by_choices.append((json.dumps([s, True]), s + " [asc]"))
order_by_choices.append((json.dumps([s, False]), s + " [desc]"))
verbose_map = {'__timestamp': 'Time'}
verbose_map.update({
o.metric_name: o.verbose_name or o.metric_name
for o in self.metrics
})
verbose_map.update({
o.column_name: o.verbose_name or o.column_name
for o in self.columns
})
verbose_map = {"__timestamp": "Time"}
verbose_map.update(
{o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
)
verbose_map.update(
{o.column_name: o.verbose_name or o.column_name for o in self.columns}
)
return {
# simple fields
'id': self.id,
'column_formats': self.column_formats,
'description': self.description,
'database': self.database.data, # pylint: disable=no-member
'default_endpoint': self.default_endpoint,
'filter_select': self.filter_select_enabled, # TODO deprecate
'filter_select_enabled': self.filter_select_enabled,
'name': self.name,
'datasource_name': self.datasource_name,
'type': self.type,
'schema': self.schema,
'offset': self.offset,
'cache_timeout': self.cache_timeout,
'params': self.params,
'perm': self.perm,
'edit_url': self.url,
"id": self.id,
"column_formats": self.column_formats,
"description": self.description,
"database": self.database.data, # pylint: disable=no-member
"default_endpoint": self.default_endpoint,
"filter_select": self.filter_select_enabled, # TODO deprecate
"filter_select_enabled": self.filter_select_enabled,
"name": self.name,
"datasource_name": self.datasource_name,
"type": self.type,
"schema": self.schema,
"offset": self.offset,
"cache_timeout": self.cache_timeout,
"params": self.params,
"perm": self.perm,
"edit_url": self.url,
# sqla-specific
'sql': self.sql,
"sql": self.sql,
# one to many
'columns': [o.data for o in self.columns],
'metrics': [o.data for o in self.metrics],
"columns": [o.data for o in self.columns],
"metrics": [o.data for o in self.metrics],
# TODO deprecate, move logic to JS
'order_by_choices': order_by_choices,
'owners': [owner.id for owner in self.owners],
'verbose_map': verbose_map,
'select_star': self.select_star,
"order_by_choices": order_by_choices,
"owners": [owner.id for owner in self.owners],
"verbose_map": verbose_map,
"select_star": self.select_star,
}
@staticmethod
def filter_values_handler(
values, target_column_is_numeric=False, is_list_target=False):
values, target_column_is_numeric=False, is_list_target=False
):
def handle_single_value(v):
# backward compatibility with previous <select> components
if isinstance(v, str):
v = v.strip('\t\n\'"')
v = v.strip("\t\n'\"")
if target_column_is_numeric:
# For backwards compatibility and edge cases
# where a column data type might have changed
v = utils.string_to_num(v)
if v == '<NULL>':
if v == "<NULL>":
return None
elif v == '<empty string>':
return ''
elif v == "<empty string>":
return ""
return v
if isinstance(values, (list, tuple)):
values = [handle_single_value(v) for v in values]
else:
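
filter_values_handler above normalizes filter values coming from the frontend's previous <select> components. The standalone sketch below mirrors the single-value branch shown in this hunk so the "<NULL>" and "<empty string>" markers are easy to see; float() stands in for utils.string_to_num, which is not reproduced here.

def handle_single_value(v, target_column_is_numeric=False):
    # Mirrors the inner helper of BaseDatasource.filter_values_handler above.
    if isinstance(v, str):
        v = v.strip("\t\n'\"")
        if target_column_is_numeric:
            v = float(v)  # stand-in; the real helper tolerates non-numeric strings
        if v == "<NULL>":
            return None
        elif v == "<empty string>":
            return ""
    return v


assert handle_single_value("'<NULL>'") is None
assert handle_single_value('"<empty string>"') == ""
assert handle_single_value("'42'", target_column_is_numeric=True) == 42.0
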
@ -278,8 +269,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
if col.column_name == column_name:
return col
def get_fk_many_from_list(
self, object_list, fkmany, fkmany_class, key_attr):
def get_fk_many_from_list(self, object_list, fkmany, fkmany_class, key_attr):
"""Update ORM one-to-many list from object list
Used for syncing metrics and columns using the same code"""
@ -302,13 +292,10 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
for obj in object_list:
key = obj.get(key_attr)
if key not in orm_keys:
del obj['id']
del obj["id"]
orm_kwargs = {}
for k in obj:
if (
k in fkmany_class.update_from_object_fields and
k in obj
):
if k in fkmany_class.update_from_object_fields and k in obj:
orm_kwargs[k] = obj[k]
new_obj = fkmany_class(**orm_kwargs)
new_fks.append(new_obj)
@ -329,16 +316,18 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
for attr in self.update_from_object_fields:
setattr(self, attr, obj.get(attr))
self.owners = obj.get('owners', [])
self.owners = obj.get("owners", [])
# Syncing metrics
metrics = self.get_fk_many_from_list(
obj.get('metrics'), self.metrics, self.metric_class, 'metric_name')
obj.get("metrics"), self.metrics, self.metric_class, "metric_name"
)
self.metrics = metrics
# Syncing columns
self.columns = self.get_fk_many_from_list(
obj.get('columns'), self.columns, self.column_class, 'column_name')
obj.get("columns"), self.columns, self.column_class, "column_name"
)
class BaseColumn(AuditMixinNullable, ImportMixin):
@ -363,32 +352,31 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
return self.column_name
num_types = (
'DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'NUMBER',
'LONG', 'REAL', 'NUMERIC', 'DECIMAL', 'MONEY',
"DOUBLE",
"FLOAT",
"INT",
"BIGINT",
"NUMBER",
"LONG",
"REAL",
"NUMERIC",
"DECIMAL",
"MONEY",
)
date_types = ('DATE', 'TIME', 'DATETIME')
str_types = ('VARCHAR', 'STRING', 'CHAR')
date_types = ("DATE", "TIME", "DATETIME")
str_types = ("VARCHAR", "STRING", "CHAR")
@property
def is_num(self):
return (
self.type and
any([t in self.type.upper() for t in self.num_types])
)
return self.type and any([t in self.type.upper() for t in self.num_types])
@property
def is_time(self):
return (
self.type and
any([t in self.type.upper() for t in self.date_types])
)
return self.type and any([t in self.type.upper() for t in self.date_types])
@property
def is_string(self):
return (
self.type and
any([t in self.type.upper() for t in self.str_types])
)
return self.type and any([t in self.type.upper() for t in self.str_types])
@property
def expression(self):
@ -397,9 +385,17 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
@property
def data(self):
attrs = (
'id', 'column_name', 'verbose_name', 'description', 'expression',
'filterable', 'groupby', 'is_dttm', 'type',
'database_expression', 'python_date_format',
"id",
"column_name",
"verbose_name",
"description",
"expression",
"filterable",
"groupby",
"is_dttm",
"type",
"database_expression",
"python_date_format",
)
return {s: getattr(self, s) for s in attrs if hasattr(self, s)}
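
is_num, is_time and is_string above classify a column by substring-matching its declared type against the tuples at the top of BaseColumn. A small standalone illustration of that test (plain Python, not Superset code):

num_types = (
    "DOUBLE", "FLOAT", "INT", "BIGINT", "NUMBER",
    "LONG", "REAL", "NUMERIC", "DECIMAL", "MONEY",
)


def is_num(type_str):
    # Same check as the property above: any numeric marker contained in the upper-cased type.
    return bool(type_str) and any(t in type_str.upper() for t in num_types)


assert is_num("BIGINT(20)")
assert is_num("double precision")
assert not is_num("VARCHAR(255)")
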
@ -432,6 +428,7 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
backref=backref('metrics', cascade='all, delete-orphan'),
enable_typechecks=False)
"""
@property
def perm(self):
raise NotImplementedError()
@ -443,6 +440,12 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
@property
def data(self):
attrs = (
'id', 'metric_name', 'verbose_name', 'description', 'expression',
'warning_text', 'd3format')
"id",
"metric_name",
"verbose_name",
"description",
"expression",
"warning_text",
"d3format",
)
return {s: getattr(self, s) for s in attrs}

View File

@ -24,7 +24,10 @@ from superset.views.base import SupersetModelView
class DatasourceModelView(SupersetModelView):
def pre_delete(self, obj):
if obj.slices:
raise SupersetException(Markup(
'Cannot delete a datasource that has slices attached to it.'
"Here's the list of associated charts: " +
''.join([o.slice_name for o in obj.slices])))
raise SupersetException(
Markup(
"Cannot delete a datasource that has slices attached to it."
"Here's the list of associated charts: "
+ "".join([o.slice_name for o in obj.slices])
)
)

View File

@ -51,15 +51,21 @@ class ConnectorRegistry(object):
return datasources
@classmethod
def get_datasource_by_name(cls, session, datasource_type, datasource_name,
schema, database_name):
def get_datasource_by_name(
cls, session, datasource_type, datasource_name, schema, database_name
):
datasource_class = ConnectorRegistry.sources[datasource_type]
datasources = session.query(datasource_class).all()
# Filter datasources that don't have a database.
db_ds = [d for d in datasources if d.database and
d.database.name == database_name and
d.name == datasource_name and schema == schema]
db_ds = [
d
for d in datasources
if d.database
and d.database.name == database_name
and d.name == datasource_name
and schema == schema
]
return db_ds[0]
@classmethod
@ -87,8 +93,8 @@ class ConnectorRegistry(object):
)
@classmethod
def query_datasources_by_name(
cls, session, database, datasource_name, schema=None):
def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=None)
session, database, datasource_name, schema=None
)

File diff suppressed because it is too large

View File

@ -33,9 +33,14 @@ from superset.connectors.base.views import DatasourceModelView
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils import core as utils
from superset.views.base import (
BaseSupersetView, DatasourceFilter, DeleteMixin,
get_datasource_exist_error_msg, ListWidgetWithCheckboxes, SupersetModelView,
validate_json, YamlExportMixin,
BaseSupersetView,
DatasourceFilter,
DeleteMixin,
get_datasource_exist_error_msg,
ListWidgetWithCheckboxes,
SupersetModelView,
validate_json,
YamlExportMixin,
)
from . import models
@ -43,48 +48,56 @@ from . import models
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidColumn)
list_title = _('Columns')
show_title = _('Show Druid Column')
add_title = _('Add Druid Column')
edit_title = _('Edit Druid Column')
list_title = _("Columns")
show_title = _("Show Druid Column")
add_title = _("Add Druid Column")
edit_title = _("Edit Druid Column")
list_widget = ListWidgetWithCheckboxes
edit_columns = [
'column_name', 'verbose_name', 'description', 'dimension_spec_json', 'datasource',
'groupby', 'filterable']
"column_name",
"verbose_name",
"description",
"dimension_spec_json",
"datasource",
"groupby",
"filterable",
]
add_columns = edit_columns
list_columns = ['column_name', 'verbose_name', 'type', 'groupby', 'filterable']
list_columns = ["column_name", "verbose_name", "type", "groupby", "filterable"]
can_delete = False
page_size = 500
label_columns = {
'column_name': _('Column'),
'type': _('Type'),
'datasource': _('Datasource'),
'groupby': _('Groupable'),
'filterable': _('Filterable'),
"column_name": _("Column"),
"type": _("Type"),
"datasource": _("Datasource"),
"groupby": _("Groupable"),
"filterable": _("Filterable"),
}
description_columns = {
'filterable': _(
'Whether this column is exposed in the `Filters` section '
'of the explore view.'),
'dimension_spec_json': utils.markdown(
'this field can be used to specify '
'a `dimensionSpec` as documented [here]'
'(http://druid.io/docs/latest/querying/dimensionspecs.html). '
'Make sure to input valid JSON and that the '
'`outputName` matches the `column_name` defined '
'above.',
True),
"filterable": _(
"Whether this column is exposed in the `Filters` section "
"of the explore view."
),
"dimension_spec_json": utils.markdown(
"this field can be used to specify "
"a `dimensionSpec` as documented [here]"
"(http://druid.io/docs/latest/querying/dimensionspecs.html). "
"Make sure to input valid JSON and that the "
"`outputName` matches the `column_name` defined "
"above.",
True,
),
}
add_form_extra_fields = {
'datasource': QuerySelectField(
'Datasource',
"datasource": QuerySelectField(
"Datasource",
query_factory=lambda: db.session().query(models.DruidDatasource),
allow_blank=True,
widget=Select2Widget(extra_classes='readonly'),
),
widget=Select2Widget(extra_classes="readonly"),
)
}
edit_form_extra_fields = add_form_extra_fields
@ -96,18 +109,20 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
try:
dimension_spec = json.loads(col.dimension_spec_json)
except ValueError as e:
raise ValueError('Invalid Dimension Spec JSON: ' + str(e))
raise ValueError("Invalid Dimension Spec JSON: " + str(e))
if not isinstance(dimension_spec, dict):
raise ValueError('Dimension Spec must be a JSON object')
if 'outputName' not in dimension_spec:
raise ValueError('Dimension Spec does not contain `outputName`')
if 'dimension' not in dimension_spec:
raise ValueError('Dimension Spec is missing `dimension`')
raise ValueError("Dimension Spec must be a JSON object")
if "outputName" not in dimension_spec:
raise ValueError("Dimension Spec does not contain `outputName`")
if "dimension" not in dimension_spec:
raise ValueError("Dimension Spec is missing `dimension`")
# `outputName` should be the same as the `column_name`
if dimension_spec['outputName'] != col.column_name:
if dimension_spec["outputName"] != col.column_name:
raise ValueError(
'`outputName` [{}] unequal to `column_name` [{}]'
.format(dimension_spec['outputName'], col.column_name))
"`outputName` [{}] unequal to `column_name` [{}]".format(
dimension_spec["outputName"], col.column_name
)
)
def post_update(self, col):
col.refresh_metrics()
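
The validation above only accepts a JSON object carrying both dimension and outputName, with outputName equal to the column's column_name. For a hypothetical column named country, a passing dimension_spec_json could look like the sketch below; the field values are illustrative.

import json

dimension_spec_json = json.dumps(
    {
        "type": "default",
        "dimension": "country_iso_code",  # hypothetical source dimension
        "outputName": "country",  # must equal the column_name it is attached to
    }
)
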
@ -122,60 +137,73 @@ appbuilder.add_view_no_menu(DruidColumnInlineView)
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.DruidMetric)
list_title = _('Metrics')
show_title = _('Show Druid Metric')
add_title = _('Add Druid Metric')
edit_title = _('Edit Druid Metric')
list_title = _("Metrics")
show_title = _("Show Druid Metric")
add_title = _("Add Druid Metric")
edit_title = _("Edit Druid Metric")
list_columns = ['metric_name', 'verbose_name', 'metric_type']
list_columns = ["metric_name", "verbose_name", "metric_type"]
edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
'datasource', 'd3format', 'is_restricted', 'warning_text']
"metric_name",
"description",
"verbose_name",
"metric_type",
"json",
"datasource",
"d3format",
"is_restricted",
"warning_text",
]
add_columns = edit_columns
page_size = 500
validators_columns = {
'json': [validate_json],
}
validators_columns = {"json": [validate_json]}
description_columns = {
'metric_type': utils.markdown(
'use `postagg` as the metric type if you are defining a '
'[Druid Post Aggregation]'
'(http://druid.io/docs/latest/querying/post-aggregations.html)',
True),
'is_restricted': _('Whether access to this metric is restricted '
'to certain roles. Only roles with the permission '
"'metric access on XXX (the name of this metric)' "
'are allowed to access this metric'),
"metric_type": utils.markdown(
"use `postagg` as the metric type if you are defining a "
"[Druid Post Aggregation]"
"(http://druid.io/docs/latest/querying/post-aggregations.html)",
True,
),
"is_restricted": _(
"Whether access to this metric is restricted "
"to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"
),
}
label_columns = {
'metric_name': _('Metric'),
'description': _('Description'),
'verbose_name': _('Verbose Name'),
'metric_type': _('Type'),
'json': _('JSON'),
'datasource': _('Druid Datasource'),
'warning_text': _('Warning Message'),
'is_restricted': _('Is Restricted'),
"metric_name": _("Metric"),
"description": _("Description"),
"verbose_name": _("Verbose Name"),
"metric_type": _("Type"),
"json": _("JSON"),
"datasource": _("Druid Datasource"),
"warning_text": _("Warning Message"),
"is_restricted": _("Is Restricted"),
}
add_form_extra_fields = {
'datasource': QuerySelectField(
'Datasource',
"datasource": QuerySelectField(
"Datasource",
query_factory=lambda: db.session().query(models.DruidDatasource),
allow_blank=True,
widget=Select2Widget(extra_classes='readonly'),
),
widget=Select2Widget(extra_classes="readonly"),
)
}
edit_form_extra_fields = add_form_extra_fields
def post_add(self, metric):
if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm())
security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
def post_update(self, metric):
if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm())
security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
appbuilder.add_view_no_menu(DruidMetricInlineView)
@ -184,57 +212,63 @@ appbuilder.add_view_no_menu(DruidMetricInlineView)
class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.DruidCluster)
list_title = _('Druid Clusters')
show_title = _('Show Druid Cluster')
add_title = _('Add Druid Cluster')
edit_title = _('Edit Druid Cluster')
list_title = _("Druid Clusters")
show_title = _("Show Druid Cluster")
add_title = _("Add Druid Cluster")
edit_title = _("Edit Druid Cluster")
add_columns = [
'verbose_name', 'broker_host', 'broker_port',
'broker_user', 'broker_pass', 'broker_endpoint',
'cache_timeout', 'cluster_name',
"verbose_name",
"broker_host",
"broker_port",
"broker_user",
"broker_pass",
"broker_endpoint",
"cache_timeout",
"cluster_name",
]
edit_columns = add_columns
list_columns = ['cluster_name', 'metadata_last_refreshed']
search_columns = ('cluster_name',)
list_columns = ["cluster_name", "metadata_last_refreshed"]
search_columns = ("cluster_name",)
label_columns = {
'cluster_name': _('Cluster'),
'broker_host': _('Broker Host'),
'broker_port': _('Broker Port'),
'broker_user': _('Broker Username'),
'broker_pass': _('Broker Password'),
'broker_endpoint': _('Broker Endpoint'),
'verbose_name': _('Verbose Name'),
'cache_timeout': _('Cache Timeout'),
'metadata_last_refreshed': _('Metadata Last Refreshed'),
"cluster_name": _("Cluster"),
"broker_host": _("Broker Host"),
"broker_port": _("Broker Port"),
"broker_user": _("Broker Username"),
"broker_pass": _("Broker Password"),
"broker_endpoint": _("Broker Endpoint"),
"verbose_name": _("Verbose Name"),
"cache_timeout": _("Cache Timeout"),
"metadata_last_refreshed": _("Metadata Last Refreshed"),
}
description_columns = {
'cache_timeout': _(
'Duration (in seconds) of the caching timeout for this cluster. '
'A timeout of 0 indicates that the cache never expires. '
'Note this defaults to the global timeout if undefined.'),
'broker_user': _(
'Druid supports basic authentication. See '
'[auth](http://druid.io/docs/latest/design/auth.html) and '
'druid-basic-security extension',
"cache_timeout": _(
"Duration (in seconds) of the caching timeout for this cluster. "
"A timeout of 0 indicates that the cache never expires. "
"Note this defaults to the global timeout if undefined."
),
'broker_pass': _(
'Druid supports basic authentication. See '
'[auth](http://druid.io/docs/latest/design/auth.html) and '
'druid-basic-security extension',
"broker_user": _(
"Druid supports basic authentication. See "
"[auth](http://druid.io/docs/latest/design/auth.html) and "
"druid-basic-security extension"
),
"broker_pass": _(
"Druid supports basic authentication. See "
"[auth](http://druid.io/docs/latest/design/auth.html) and "
"druid-basic-security extension"
),
}
edit_form_extra_fields = {
'cluster_name': QuerySelectField(
'Cluster',
"cluster_name": QuerySelectField(
"Cluster",
query_factory=lambda: db.session().query(models.DruidCluster),
widget=Select2Widget(extra_classes='readonly'),
),
widget=Select2Widget(extra_classes="readonly"),
)
}
def pre_add(self, cluster):
security_manager.add_permission_view_menu('database_access', cluster.perm)
security_manager.add_permission_view_menu("database_access", cluster.perm)
def pre_update(self, cluster):
self.pre_add(cluster)
@ -245,112 +279,118 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): #
appbuilder.add_view(
DruidClusterModelView,
name='Druid Clusters',
label=__('Druid Clusters'),
icon='fa-cubes',
category='Sources',
category_label=__('Sources'),
category_icon='fa-database',
name="Druid Clusters",
label=__("Druid Clusters"),
icon="fa-cubes",
category="Sources",
category_label=__("Sources"),
category_icon="fa-database",
)
class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
class DruidDatasourceModelView(
DatasourceModelView, DeleteMixin, YamlExportMixin
): # noqa
datamodel = SQLAInterface(models.DruidDatasource)
list_title = _('Druid Datasources')
show_title = _('Show Druid Datasource')
add_title = _('Add Druid Datasource')
edit_title = _('Edit Druid Datasource')
list_title = _("Druid Datasources")
show_title = _("Show Druid Datasource")
add_title = _("Add Druid Datasource")
edit_title = _("Edit Druid Datasource")
list_columns = [
'datasource_link', 'cluster', 'changed_by_', 'modified']
order_columns = ['datasource_link', 'modified']
list_columns = ["datasource_link", "cluster", "changed_by_", "modified"]
order_columns = ["datasource_link", "modified"]
related_views = [DruidColumnInlineView, DruidMetricInlineView]
edit_columns = [
'datasource_name', 'cluster', 'description', 'owners',
'is_hidden',
'filter_select_enabled', 'fetch_values_from',
'default_endpoint', 'offset', 'cache_timeout']
search_columns = (
'datasource_name', 'cluster', 'description', 'owners',
)
"datasource_name",
"cluster",
"description",
"owners",
"is_hidden",
"filter_select_enabled",
"fetch_values_from",
"default_endpoint",
"offset",
"cache_timeout",
]
search_columns = ("datasource_name", "cluster", "description", "owners")
add_columns = edit_columns
show_columns = add_columns + ['perm', 'slices']
show_columns = add_columns + ["perm", "slices"]
page_size = 500
base_order = ('datasource_name', 'asc')
base_order = ("datasource_name", "asc")
description_columns = {
'slices': _(
'The list of charts associated with this table. By '
'altering this datasource, you may change how these associated '
'charts behave. '
'Also note that charts need to point to a datasource, so '
'this form will fail at saving if removing charts from a '
'datasource. If you want to change the datasource for a chart, '
"overwrite the chart from the 'explore view'"),
'offset': _('Timezone offset (in hours) for this datasource'),
'description': Markup(
"slices": _(
"The list of charts associated with this table. By "
"altering this datasource, you may change how these associated "
"charts behave. "
"Also note that charts need to point to a datasource, so "
"this form will fail at saving if removing charts from a "
"datasource. If you want to change the datasource for a chart, "
"overwrite the chart from the 'explore view'"
),
"offset": _("Timezone offset (in hours) for this datasource"),
"description": Markup(
'Supports <a href="'
'https://daringfireball.net/projects/markdown/">markdown</a>'),
'fetch_values_from': _(
'Time expression to use as a predicate when retrieving '
'distinct values to populate the filter component. '
'Only applies when `Enable Filter Select` is on. If '
'you enter `7 days ago`, the distinct list of values in '
'the filter will be populated based on the distinct value over '
'the past week'),
'filter_select_enabled': _(
'https://daringfireball.net/projects/markdown/">markdown</a>'
),
"fetch_values_from": _(
"Time expression to use as a predicate when retrieving "
"distinct values to populate the filter component. "
"Only applies when `Enable Filter Select` is on. If "
"you enter `7 days ago`, the distinct list of values in "
"the filter will be populated based on the distinct value over "
"the past week"
),
"filter_select_enabled": _(
"Whether to populate the filter's dropdown in the explore "
"view's filter section with a list of distinct values fetched "
'from the backend on the fly'),
'default_endpoint': _(
'Redirects to this endpoint when clicking on the datasource '
'from the datasource list'),
'cache_timeout': _(
'Duration (in seconds) of the caching timeout for this datasource. '
'A timeout of 0 indicates that the cache never expires. '
'Note this defaults to the cluster timeout if undefined.'),
"from the backend on the fly"
),
"default_endpoint": _(
"Redirects to this endpoint when clicking on the datasource "
"from the datasource list"
),
"cache_timeout": _(
"Duration (in seconds) of the caching timeout for this datasource. "
"A timeout of 0 indicates that the cache never expires. "
"Note this defaults to the cluster timeout if undefined."
),
}
base_filters = [['id', DatasourceFilter, lambda: []]]
base_filters = [["id", DatasourceFilter, lambda: []]]
label_columns = {
'slices': _('Associated Charts'),
'datasource_link': _('Data Source'),
'cluster': _('Cluster'),
'description': _('Description'),
'owners': _('Owners'),
'is_hidden': _('Is Hidden'),
'filter_select_enabled': _('Enable Filter Select'),
'default_endpoint': _('Default Endpoint'),
'offset': _('Time Offset'),
'cache_timeout': _('Cache Timeout'),
'datasource_name': _('Datasource Name'),
'fetch_values_from': _('Fetch Values From'),
'changed_by_': _('Changed By'),
'modified': _('Modified'),
"slices": _("Associated Charts"),
"datasource_link": _("Data Source"),
"cluster": _("Cluster"),
"description": _("Description"),
"owners": _("Owners"),
"is_hidden": _("Is Hidden"),
"filter_select_enabled": _("Enable Filter Select"),
"default_endpoint": _("Default Endpoint"),
"offset": _("Time Offset"),
"cache_timeout": _("Cache Timeout"),
"datasource_name": _("Datasource Name"),
"fetch_values_from": _("Fetch Values From"),
"changed_by_": _("Changed By"),
"modified": _("Modified"),
}
def pre_add(self, datasource):
with db.session.no_autoflush:
query = (
db.session.query(models.DruidDatasource)
.filter(models.DruidDatasource.datasource_name ==
datasource.datasource_name,
models.DruidDatasource.cluster_name ==
datasource.cluster.id)
query = db.session.query(models.DruidDatasource).filter(
models.DruidDatasource.datasource_name == datasource.datasource_name,
models.DruidDatasource.cluster_name == datasource.cluster.id,
)
if db.session.query(query.exists()).scalar():
raise Exception(get_datasource_exist_error_msg(
datasource.full_name))
raise Exception(get_datasource_exist_error_msg(datasource.full_name))
def post_add(self, datasource):
datasource.refresh_metrics()
security_manager.add_permission_view_menu(
'datasource_access',
datasource.get_perm(),
"datasource_access", datasource.get_perm()
)
if datasource.schema:
security_manager.add_permission_view_menu(
'schema_access',
datasource.schema_perm,
"schema_access", datasource.schema_perm
)
def post_update(self, datasource):
@ -362,22 +402,23 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin
appbuilder.add_view(
DruidDatasourceModelView,
'Druid Datasources',
label=__('Druid Datasources'),
category='Sources',
category_label=__('Sources'),
icon='fa-cube')
"Druid Datasources",
label=__("Druid Datasources"),
category="Sources",
category_label=__("Sources"),
icon="fa-cube",
)
class Druid(BaseSupersetView):
"""The base views for Superset!"""
@has_access
@expose('/refresh_datasources/')
@expose("/refresh_datasources/")
def refresh_datasources(self, refreshAll=True):
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
for cluster in session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
valid_cluster = True
@ -387,21 +428,25 @@ class Druid(BaseSupersetView):
valid_cluster = False
flash(
"Error while processing cluster '{}'\n{}".format(
cluster_name, utils.error_msg_from_exception(e)),
'danger')
cluster_name, utils.error_msg_from_exception(e)
),
"danger",
)
logging.exception(e)
pass
if valid_cluster:
cluster.metadata_last_refreshed = datetime.now()
flash(
_('Refreshed metadata from cluster [{}]').format(
cluster.cluster_name),
'info')
_("Refreshed metadata from cluster [{}]").format(
cluster.cluster_name
),
"info",
)
session.commit()
return redirect('/druiddatasourcemodelview/list/')
return redirect("/druiddatasourcemodelview/list/")
@has_access
@expose('/scan_new_datasources/')
@expose("/scan_new_datasources/")
def scan_new_datasources(self):
"""
Calling this endpoint will cause a scan for new
@ -413,21 +458,23 @@ class Druid(BaseSupersetView):
appbuilder.add_view_no_menu(Druid)
appbuilder.add_link(
'Scan New Datasources',
label=__('Scan New Datasources'),
href='/druid/scan_new_datasources/',
category='Sources',
category_label=__('Sources'),
category_icon='fa-database',
icon='fa-refresh')
"Scan New Datasources",
label=__("Scan New Datasources"),
href="/druid/scan_new_datasources/",
category="Sources",
category_label=__("Sources"),
category_icon="fa-database",
icon="fa-refresh",
)
appbuilder.add_link(
'Refresh Druid Metadata',
label=__('Refresh Druid Metadata'),
href='/druid/refresh_datasources/',
category='Sources',
category_label=__('Sources'),
category_icon='fa-database',
icon='fa-cog')
"Refresh Druid Metadata",
label=__("Refresh Druid Metadata"),
href="/druid/refresh_datasources/",
category="Sources",
category_label=__("Sources"),
category_icon="fa-database",
icon="fa-cog",
)
appbuilder.add_separator('Sources')
appbuilder.add_separator("Sources")

File diff suppressed because it is too large

View File

@ -32,8 +32,12 @@ from superset import appbuilder, db, security_manager
from superset.connectors.base.views import DatasourceModelView
from superset.utils import core as utils
from superset.views.base import (
DatasourceFilter, DeleteMixin, get_datasource_exist_error_msg,
ListWidgetWithCheckboxes, SupersetModelView, YamlExportMixin,
DatasourceFilter,
DeleteMixin,
get_datasource_exist_error_msg,
ListWidgetWithCheckboxes,
SupersetModelView,
YamlExportMixin,
)
from . import models
@ -43,79 +47,103 @@ logger = logging.getLogger(__name__)
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.TableColumn)
list_title = _('Columns')
show_title = _('Show Column')
add_title = _('Add Column')
edit_title = _('Edit Column')
list_title = _("Columns")
show_title = _("Show Column")
add_title = _("Add Column")
edit_title = _("Edit Column")
can_delete = False
list_widget = ListWidgetWithCheckboxes
edit_columns = [
'column_name', 'verbose_name', 'description',
'type', 'groupby', 'filterable',
'table', 'expression',
'is_dttm', 'python_date_format', 'database_expression']
"column_name",
"verbose_name",
"description",
"type",
"groupby",
"filterable",
"table",
"expression",
"is_dttm",
"python_date_format",
"database_expression",
]
add_columns = edit_columns
list_columns = [
'column_name', 'verbose_name', 'type', 'groupby', 'filterable',
'is_dttm']
"column_name",
"verbose_name",
"type",
"groupby",
"filterable",
"is_dttm",
]
page_size = 500
description_columns = {
'is_dttm': _(
'Whether to make this column available as a '
'[Time Granularity] option, column has to be DATETIME or '
'DATETIME-like'),
'filterable': _(
'Whether this column is exposed in the `Filters` section '
'of the explore view.'),
'type': _(
'The data type that was inferred by the database. '
'It may be necessary to input a type manually for '
'expression-defined columns in some cases. In most case '
'users should not need to alter this.'),
'expression': utils.markdown(
'a valid, *non-aggregating* SQL expression as supported by the '
'underlying backend. Example: `substr(name, 1, 1)`', True),
'python_date_format': utils.markdown(Markup(
'The pattern of timestamp format, use '
'<a href="https://docs.python.org/2/library/'
'datetime.html#strftime-strptime-behavior">'
'python datetime string pattern</a> '
'expression. If time is stored in epoch '
'format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` '
'below empty if timestamp is stored in '
'String or Integer(epoch) type'), True),
'database_expression': utils.markdown(
'The database expression to cast internal datetime '
'constants to database date/timestamp type according to the DBAPI. '
'The expression should follow the pattern of '
'%Y-%m-%d %H:%M:%S, based on different DBAPI. '
'The string should be a python string formatter \n'
"is_dttm": _(
"Whether to make this column available as a "
"[Time Granularity] option, column has to be DATETIME or "
"DATETIME-like"
),
"filterable": _(
"Whether this column is exposed in the `Filters` section "
"of the explore view."
),
"type": _(
"The data type that was inferred by the database. "
"It may be necessary to input a type manually for "
"expression-defined columns in some cases. In most case "
"users should not need to alter this."
),
"expression": utils.markdown(
"a valid, *non-aggregating* SQL expression as supported by the "
"underlying backend. Example: `substr(name, 1, 1)`",
True,
),
"python_date_format": utils.markdown(
Markup(
"The pattern of timestamp format, use "
'<a href="https://docs.python.org/2/library/'
'datetime.html#strftime-strptime-behavior">'
"python datetime string pattern</a> "
"expression. If time is stored in epoch "
"format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
"below empty if timestamp is stored in "
"String or Integer(epoch) type"
),
True,
),
"database_expression": utils.markdown(
"The database expression to cast internal datetime "
"constants to database date/timestamp type according to the DBAPI. "
"The expression should follow the pattern of "
"%Y-%m-%d %H:%M:%S, based on different DBAPI. "
"The string should be a python string formatter \n"
"`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle "
'Superset uses default expression based on DB URI if this '
'field is blank.', True),
"Superset uses default expression based on DB URI if this "
"field is blank.",
True,
),
}
label_columns = {
'column_name': _('Column'),
'verbose_name': _('Verbose Name'),
'description': _('Description'),
'groupby': _('Groupable'),
'filterable': _('Filterable'),
'table': _('Table'),
'expression': _('Expression'),
'is_dttm': _('Is temporal'),
'python_date_format': _('Datetime Format'),
'database_expression': _('Database Expression'),
'type': _('Type'),
"column_name": _("Column"),
"verbose_name": _("Verbose Name"),
"description": _("Description"),
"groupby": _("Groupable"),
"filterable": _("Filterable"),
"table": _("Table"),
"expression": _("Expression"),
"is_dttm": _("Is temporal"),
"python_date_format": _("Datetime Format"),
"database_expression": _("Database Expression"),
"type": _("Type"),
}
add_form_extra_fields = {
'table': QuerySelectField(
'Table',
"table": QuerySelectField(
"Table",
query_factory=lambda: db.session().query(models.SqlaTable),
allow_blank=True,
widget=Select2Widget(extra_classes='readonly'),
),
widget=Select2Widget(extra_classes="readonly"),
)
}
edit_form_extra_fields = add_form_extra_fields
@ -127,63 +155,80 @@ appbuilder.add_view_no_menu(TableColumnInlineView)
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa
datamodel = SQLAInterface(models.SqlMetric)
list_title = _('Metrics')
show_title = _('Show Metric')
add_title = _('Add Metric')
edit_title = _('Edit Metric')
list_title = _("Metrics")
show_title = _("Show Metric")
add_title = _("Add Metric")
edit_title = _("Edit Metric")
list_columns = ['metric_name', 'verbose_name', 'metric_type']
list_columns = ["metric_name", "verbose_name", "metric_type"]
edit_columns = [
'metric_name', 'description', 'verbose_name', 'metric_type',
'expression', 'table', 'd3format', 'is_restricted', 'warning_text']
"metric_name",
"description",
"verbose_name",
"metric_type",
"expression",
"table",
"d3format",
"is_restricted",
"warning_text",
]
description_columns = {
'expression': utils.markdown(
'a valid, *aggregating* SQL expression as supported by the '
'underlying backend. Example: `count(DISTINCT userid)`', True),
'is_restricted': _('Whether access to this metric is restricted '
'to certain roles. Only roles with the permission '
"'metric access on XXX (the name of this metric)' "
'are allowed to access this metric'),
'd3format': utils.markdown(
'd3 formatting string as defined [here]'
'(https://github.com/d3/d3-format/blob/master/README.md#format). '
'For instance, this default formatting applies in the Table '
'visualization and allow for different metric to use different '
'formats', True,
"expression": utils.markdown(
"a valid, *aggregating* SQL expression as supported by the "
"underlying backend. Example: `count(DISTINCT userid)`",
True,
),
"is_restricted": _(
"Whether access to this metric is restricted "
"to certain roles. Only roles with the permission "
"'metric access on XXX (the name of this metric)' "
"are allowed to access this metric"
),
"d3format": utils.markdown(
"d3 formatting string as defined [here]"
"(https://github.com/d3/d3-format/blob/master/README.md#format). "
"For instance, this default formatting applies in the Table "
"visualization and allow for different metric to use different "
"formats",
True,
),
}
add_columns = edit_columns
page_size = 500
label_columns = {
'metric_name': _('Metric'),
'description': _('Description'),
'verbose_name': _('Verbose Name'),
'metric_type': _('Type'),
'expression': _('SQL Expression'),
'table': _('Table'),
'd3format': _('D3 Format'),
'is_restricted': _('Is Restricted'),
'warning_text': _('Warning Message'),
"metric_name": _("Metric"),
"description": _("Description"),
"verbose_name": _("Verbose Name"),
"metric_type": _("Type"),
"expression": _("SQL Expression"),
"table": _("Table"),
"d3format": _("D3 Format"),
"is_restricted": _("Is Restricted"),
"warning_text": _("Warning Message"),
}
add_form_extra_fields = {
'table': QuerySelectField(
'Table',
"table": QuerySelectField(
"Table",
query_factory=lambda: db.session().query(models.SqlaTable),
allow_blank=True,
widget=Select2Widget(extra_classes='readonly'),
),
widget=Select2Widget(extra_classes="readonly"),
)
}
edit_form_extra_fields = add_form_extra_fields
def post_add(self, metric):
if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm())
security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
def post_update(self, metric):
if metric.is_restricted:
security_manager.add_permission_view_menu('metric_access', metric.get_perm())
security_manager.add_permission_view_menu(
"metric_access", metric.get_perm()
)
appbuilder.add_view_no_menu(SqlMetricInlineView)
@ -192,104 +237,114 @@ appbuilder.add_view_no_menu(SqlMetricInlineView)
class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable)
list_title = _('Tables')
show_title = _('Show Table')
add_title = _('Import a table definition')
edit_title = _('Edit Table')
list_title = _("Tables")
show_title = _("Show Table")
add_title = _("Import a table definition")
edit_title = _("Edit Table")
list_columns = [
'link', 'database_name',
'changed_by_', 'modified']
order_columns = ['modified']
add_columns = ['database', 'schema', 'table_name']
list_columns = ["link", "database_name", "changed_by_", "modified"]
order_columns = ["modified"]
add_columns = ["database", "schema", "table_name"]
edit_columns = [
'table_name', 'sql', 'filter_select_enabled',
'fetch_values_predicate', 'database', 'schema',
'description', 'owners',
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout',
'is_sqllab_view', 'template_params',
"table_name",
"sql",
"filter_select_enabled",
"fetch_values_predicate",
"database",
"schema",
"description",
"owners",
"main_dttm_col",
"default_endpoint",
"offset",
"cache_timeout",
"is_sqllab_view",
"template_params",
]
base_filters = [['id', DatasourceFilter, lambda: []]]
show_columns = edit_columns + ['perm', 'slices']
base_filters = [["id", DatasourceFilter, lambda: []]]
show_columns = edit_columns + ["perm", "slices"]
related_views = [TableColumnInlineView, SqlMetricInlineView]
base_order = ('changed_on', 'desc')
search_columns = (
'database', 'schema', 'table_name', 'owners', 'is_sqllab_view',
)
base_order = ("changed_on", "desc")
search_columns = ("database", "schema", "table_name", "owners", "is_sqllab_view")
description_columns = {
'slices': _(
'The list of charts associated with this table. By '
'altering this datasource, you may change how these associated '
'charts behave. '
'Also note that charts need to point to a datasource, so '
'this form will fail at saving if removing charts from a '
'datasource. If you want to change the datasource for a chart, '
"overwrite the chart from the 'explore view'"),
'offset': _('Timezone offset (in hours) for this datasource'),
'table_name': _(
'Name of the table that exists in the source database'),
'schema': _(
'Schema, as used only in some databases like Postgres, Redshift '
'and DB2'),
'description': Markup(
"slices": _(
"The list of charts associated with this table. By "
"altering this datasource, you may change how these associated "
"charts behave. "
"Also note that charts need to point to a datasource, so "
"this form will fail at saving if removing charts from a "
"datasource. If you want to change the datasource for a chart, "
"overwrite the chart from the 'explore view'"
),
"offset": _("Timezone offset (in hours) for this datasource"),
"table_name": _("Name of the table that exists in the source database"),
"schema": _(
"Schema, as used only in some databases like Postgres, Redshift " "and DB2"
),
"description": Markup(
'Supports <a href="https://daringfireball.net/projects/markdown/">'
'markdown</a>'),
'sql': _(
'This field acts as a Superset view, meaning that Superset will '
'run a query against this string as a subquery.',
"markdown</a>"
),
'fetch_values_predicate': _(
'Predicate applied when fetching distinct value to '
'populate the filter control component. Supports '
'jinja template syntax. Applies only when '
'`Enable Filter Select` is on.',
"sql": _(
"This fields acts a Superset view, meaning that Superset will "
"run a query against this string as a subquery."
),
'default_endpoint': _(
'Redirects to this endpoint when clicking on the table '
'from the table list'),
'filter_select_enabled': _(
"fetch_values_predicate": _(
"Predicate applied when fetching distinct value to "
"populate the filter control component. Supports "
"jinja template syntax. Applies only when "
"`Enable Filter Select` is on."
),
"default_endpoint": _(
"Redirects to this endpoint when clicking on the table "
"from the table list"
),
"filter_select_enabled": _(
"Whether to populate the filter's dropdown in the explore "
"view's filter section with a list of distinct values fetched "
'from the backend on the fly'),
'is_sqllab_view': _(
"Whether the table was generated by the 'Visualize' flow "
'in SQL Lab'),
'template_params': _(
'A set of parameters that become available in the query using '
'Jinja templating syntax'),
'cache_timeout': _(
'Duration (in seconds) of the caching timeout for this table. '
'A timeout of 0 indicates that the cache never expires. '
'Note this defaults to the database timeout if undefined.'),
"from the backend on the fly"
),
"is_sqllab_view": _(
"Whether the table was generated by the 'Visualize' flow " "in SQL Lab"
),
"template_params": _(
"A set of parameters that become available in the query using "
"Jinja templating syntax"
),
"cache_timeout": _(
"Duration (in seconds) of the caching timeout for this table. "
"A timeout of 0 indicates that the cache never expires. "
"Note this defaults to the database timeout if undefined."
),
}
label_columns = {
'slices': _('Associated Charts'),
'link': _('Table'),
'changed_by_': _('Changed By'),
'database': _('Database'),
'database_name': _('Database'),
'changed_on_': _('Last Changed'),
'filter_select_enabled': _('Enable Filter Select'),
'schema': _('Schema'),
'default_endpoint': _('Default Endpoint'),
'offset': _('Offset'),
'cache_timeout': _('Cache Timeout'),
'table_name': _('Table Name'),
'fetch_values_predicate': _('Fetch Values Predicate'),
'owners': _('Owners'),
'main_dttm_col': _('Main Datetime Column'),
'description': _('Description'),
'is_sqllab_view': _('SQL Lab View'),
'template_params': _('Template parameters'),
'modified': _('Modified'),
"slices": _("Associated Charts"),
"link": _("Table"),
"changed_by_": _("Changed By"),
"database": _("Database"),
"database_name": _("Database"),
"changed_on_": _("Last Changed"),
"filter_select_enabled": _("Enable Filter Select"),
"schema": _("Schema"),
"default_endpoint": _("Default Endpoint"),
"offset": _("Offset"),
"cache_timeout": _("Cache Timeout"),
"table_name": _("Table Name"),
"fetch_values_predicate": _("Fetch Values Predicate"),
"owners": _("Owners"),
"main_dttm_col": _("Main Datetime Column"),
"description": _("Description"),
"is_sqllab_view": _("SQL Lab View"),
"template_params": _("Template parameters"),
"modified": _("Modified"),
}
edit_form_extra_fields = {
'database': QuerySelectField(
'Database',
"database": QuerySelectField(
"Database",
query_factory=lambda: db.session().query(models.Database),
widget=Select2Widget(extra_classes='readonly'),
),
widget=Select2Widget(extra_classes="readonly"),
)
}
def pre_add(self, table):
@ -297,34 +352,43 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
table_query = db.session.query(models.SqlaTable).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id)
models.SqlaTable.database_id == table.database.id,
)
if db.session.query(table_query.exists()).scalar():
raise Exception(
get_datasource_exist_error_msg(table.full_name))
raise Exception(get_datasource_exist_error_msg(table.full_name))
# Fail before adding if the table can't be found
try:
table.get_sqla_table_object()
except Exception as e:
logger.exception(f'Got an error in pre_add for {table.name}')
raise Exception(_(
'Table [{}] could not be found, '
'please double check your '
'database connection, schema, and '
'table name, error: {}').format(table.name, str(e)))
logger.exception(f"Got an error in pre_add for {table.name}")
raise Exception(
_(
"Table [{}] could not be found, "
"please double check your "
"database connection, schema, and "
"table name, error: {}"
).format(table.name, str(e))
)
def post_add(self, table, flash_message=True):
table.fetch_metadata()
security_manager.add_permission_view_menu('datasource_access', table.get_perm())
security_manager.add_permission_view_menu("datasource_access", table.get_perm())
if table.schema:
security_manager.add_permission_view_menu('schema_access', table.schema_perm)
security_manager.add_permission_view_menu(
"schema_access", table.schema_perm
)
if flash_message:
flash(_(
'The table was created. '
'As part of this two-phase configuration '
'process, you should now click the edit button by '
'the new table to configure it.'), 'info')
flash(
_(
"The table was created. "
"As part of this two-phase configuration "
"process, you should now click the edit button by "
"the new table to configure it."
),
"info",
)
def post_update(self, table):
self.post_add(table, flash_message=False)
@ -332,20 +396,18 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
def _delete(self, pk):
DeleteMixin._delete(self, pk)
@expose('/edit/<pk>', methods=['GET', 'POST'])
@expose("/edit/<pk>", methods=["GET", "POST"])
@has_access
def edit(self, pk):
"""Simple hack to redirect to explore view after saving"""
resp = super(TableModelView, self).edit(pk)
if isinstance(resp, str):
return resp
return redirect('/superset/explore/table/{}/'.format(pk))
return redirect("/superset/explore/table/{}/".format(pk))
@action(
'refresh',
__('Refresh Metadata'),
__('Refresh column metadata'),
'fa-refresh')
"refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
)
def refresh(self, tables):
if not isinstance(tables, list):
tables = [tables]
@ -360,26 +422,29 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): # noqa
if len(successes) > 0:
success_msg = _(
'Metadata refreshed for the following table(s): %(tables)s',
tables=', '.join([t.table_name for t in successes]))
flash(success_msg, 'info')
"Metadata refreshed for the following table(s): %(tables)s",
tables=", ".join([t.table_name for t in successes]),
)
flash(success_msg, "info")
if len(failures) > 0:
failure_msg = _(
'Unable to retrieve metadata for the following table(s): %(tables)s',
tables=', '.join([t.table_name for t in failures]))
flash(failure_msg, 'danger')
"Unable to retrieve metadata for the following table(s): %(tables)s",
tables=", ".join([t.table_name for t in failures]),
)
flash(failure_msg, "danger")
return redirect('/tablemodelview/list/')
return redirect("/tablemodelview/list/")
appbuilder.add_view_no_menu(TableModelView)
appbuilder.add_link(
'Tables',
label=__('Tables'),
href='/tablemodelview/list/?_flt_1_is_sqllab_view=y',
icon='fa-table',
category='Sources',
category_label=__('Sources'),
category_icon='fa-table')
"Tables",
label=__("Tables"),
href="/tablemodelview/list/?_flt_1_is_sqllab_view=y",
icon="fa-table",
category="Sources",
category_label=__("Sources"),
category_icon="fa-table",
)
appbuilder.add_separator('Sources')
appbuilder.add_separator("Sources")

View File

@ -28,6 +28,6 @@ from .multiformat_time_series import load_multiformat_time_series # noqa
from .paris import load_paris_iris_geojson # noqa
from .random_time_series import load_random_time_series_data # noqa
from .sf_population_polygons import load_sf_population_polygons # noqa
from .tabbed_dashboard import load_tabbed_dashboard # noqa
from .tabbed_dashboard import load_tabbed_dashboard # noqa
from .unicode_test_data import load_unicode_test_data # noqa
from .world_bank import load_world_bank_health_n_pop # noqa

View File

@ -26,30 +26,31 @@ from .helpers import TBL, get_example_data
def load_bart_lines():
tbl_name = 'bart_lines'
content = get_example_data('bart-lines.json.gz')
df = pd.read_json(content, encoding='latin-1')
df['path_json'] = df.path.map(json.dumps)
df['polyline'] = df.path.map(polyline.encode)
del df['path']
tbl_name = "bart_lines"
content = get_example_data("bart-lines.json.gz")
df = pd.read_json(content, encoding="latin-1")
df["path_json"] = df.path.map(json.dumps)
df["polyline"] = df.path.map(polyline.encode)
del df["path"]
df.to_sql(
tbl_name,
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'color': String(255),
'name': String(255),
'polyline': Text,
'path_json': Text,
"color": String(255),
"name": String(255),
"polyline": Text,
"path_json": Text,
},
index=False)
print('Creating table {} reference'.format(tbl_name))
index=False,
)
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = 'BART lines'
tbl.description = "BART lines"
tbl.database = get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
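
load_bart_lines above follows the pattern shared by these example loaders: fetch the gzipped JSON, derive serialized columns, bulk-load with DataFrame.to_sql, then register a table entry. The standalone sketch below replays just the to_sql step against an in-memory SQLite engine; the polyline encoding is omitted and the sample row is made up.

import json

import pandas as pd
from sqlalchemy import String, Text, create_engine

engine = create_engine("sqlite://")  # stand-in for db.engine
df = pd.DataFrame(
    {"name": ["Red Line"], "color": ["#ff0000"], "path": [[[37.80, -122.27]]]}
)
df["path_json"] = df.path.map(json.dumps)
del df["path"]
df.to_sql(
    "bart_lines",
    engine,
    if_exists="replace",
    chunksize=500,
    dtype={"color": String(255), "name": String(255), "path_json": Text},
    index=False,
)
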

View File

@ -38,46 +38,46 @@ from .helpers import (
def load_birth_names():
"""Loading birth name dataset from a zip file in the repo"""
data = get_example_data('birth_names.json.gz')
data = get_example_data("birth_names.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit='ms')
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf.to_sql(
'birth_names',
"birth_names",
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'ds': DateTime,
'gender': String(16),
'state': String(10),
'name': String(255),
"ds": DateTime,
"gender": String(16),
"state": String(10),
"name": String(255),
},
index=False)
print('Done loading table!')
print('-' * 80)
index=False,
)
print("Done loading table!")
print("-" * 80)
print('Creating table [birth_names] reference')
obj = db.session.query(TBL).filter_by(table_name='birth_names').first()
print("Creating table [birth_names] reference")
obj = db.session.query(TBL).filter_by(table_name="birth_names").first()
if not obj:
obj = TBL(table_name='birth_names')
obj.main_dttm_col = 'ds'
obj = TBL(table_name="birth_names")
obj.main_dttm_col = "ds"
obj.database = get_or_create_main_db()
obj.filter_select_enabled = True
if not any(col.column_name == 'num_california' for col in obj.columns):
col_state = str(column('state').compile(db.engine))
col_num = str(column('num').compile(db.engine))
obj.columns.append(TableColumn(
column_name='num_california',
expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
))
if not any(col.column_name == "num_california" for col in obj.columns):
col_state = str(column("state").compile(db.engine))
col_num = str(column("num").compile(db.engine))
obj.columns.append(
TableColumn(
column_name="num_california",
expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
)
)
if not any(col.metric_name == 'sum__num' for col in obj.metrics):
col = str(column('num').compile(db.engine))
obj.metrics.append(SqlMetric(
metric_name='sum__num',
expression=f'SUM({col})',
))
if not any(col.metric_name == "sum__num" for col in obj.metrics):
col = str(column("num").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))
db.session.merge(obj)
db.session.commit()
@ -85,149 +85,149 @@ def load_birth_names():
tbl = obj
defaults = {
'compare_lag': '10',
'compare_suffix': 'o10Y',
'limit': '25',
'granularity_sqla': 'ds',
'groupby': [],
'metric': 'sum__num',
'metrics': ['sum__num'],
'row_limit': config.get('ROW_LIMIT'),
'since': '100 years ago',
'until': 'now',
'viz_type': 'table',
'where': '',
'markup_type': 'markdown',
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity_sqla": "ds",
"groupby": [],
"metric": "sum__num",
"metrics": ["sum__num"],
"row_limit": config.get("ROW_LIMIT"),
"since": "100 years ago",
"until": "now",
"viz_type": "table",
"where": "",
"markup_type": "markdown",
}
admin = security_manager.find_user('admin')
admin = security_manager.find_user("admin")
print('Creating some slices')
print("Creating some slices")
slices = [
Slice(
slice_name='Girls',
viz_type='table',
datasource_type='table',
slice_name="Girls",
viz_type="table",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
groupby=['name'],
filters=[{
'col': 'gender',
'op': 'in',
'val': ['girl'],
}],
groupby=["name"],
filters=[{"col": "gender", "op": "in", "val": ["girl"]}],
row_limit=50,
timeseries_limit_metric='sum__num')),
timeseries_limit_metric="sum__num",
),
),
Slice(
slice_name='Boys',
viz_type='table',
datasource_type='table',
slice_name="Boys",
viz_type="table",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
groupby=['name'],
filters=[{
'col': 'gender',
'op': 'in',
'val': ['boy'],
}],
row_limit=50)),
groupby=["name"],
filters=[{"col": "gender", "op": "in", "val": ["boy"]}],
row_limit=50,
),
),
Slice(
slice_name='Participants',
viz_type='big_number',
datasource_type='table',
slice_name="Participants",
viz_type="big_number",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='big_number', granularity_sqla='ds',
compare_lag='5', compare_suffix='over 5Y')),
viz_type="big_number",
granularity_sqla="ds",
compare_lag="5",
compare_suffix="over 5Y",
),
),
Slice(
slice_name='Genders',
viz_type='pie',
datasource_type='table',
slice_name="Genders",
viz_type="pie",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='pie', groupby=['gender'])),
params=get_slice_json(defaults, viz_type="pie", groupby=["gender"]),
),
Slice(
slice_name='Genders by State',
viz_type='dist_bar',
datasource_type='table',
slice_name="Genders by State",
viz_type="dist_bar",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
adhoc_filters=[
{
'clause': 'WHERE',
'expressionType': 'SIMPLE',
'filterOptionName': '2745eae5',
'comparator': ['other'],
'operator': 'not in',
'subject': 'state',
},
"clause": "WHERE",
"expressionType": "SIMPLE",
"filterOptionName": "2745eae5",
"comparator": ["other"],
"operator": "not in",
"subject": "state",
}
],
viz_type='dist_bar',
viz_type="dist_bar",
metrics=[
{
'expressionType': 'SIMPLE',
'column': {
'column_name': 'sum_boys',
'type': 'BIGINT(20)',
},
'aggregate': 'SUM',
'label': 'Boys',
'optionName': 'metric_11',
"expressionType": "SIMPLE",
"column": {"column_name": "sum_boys", "type": "BIGINT(20)"},
"aggregate": "SUM",
"label": "Boys",
"optionName": "metric_11",
},
{
'expressionType': 'SIMPLE',
'column': {
'column_name': 'sum_girls',
'type': 'BIGINT(20)',
},
'aggregate': 'SUM',
'label': 'Girls',
'optionName': 'metric_12',
"expressionType": "SIMPLE",
"column": {"column_name": "sum_girls", "type": "BIGINT(20)"},
"aggregate": "SUM",
"label": "Girls",
"optionName": "metric_12",
},
],
groupby=['state'])),
groupby=["state"],
),
),
Slice(
slice_name='Trends',
viz_type='line',
datasource_type='table',
slice_name="Trends",
viz_type="line",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='line', groupby=['name'],
granularity_sqla='ds', rich_tooltip=True, show_legend=True)),
viz_type="line",
groupby=["name"],
granularity_sqla="ds",
rich_tooltip=True,
show_legend=True,
),
),
Slice(
slice_name='Average and Sum Trends',
viz_type='dual_line',
datasource_type='table',
slice_name="Average and Sum Trends",
viz_type="dual_line",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='dual_line',
viz_type="dual_line",
metric={
'expressionType': 'SIMPLE',
'column': {
'column_name': 'num',
'type': 'BIGINT(20)',
},
'aggregate': 'AVG',
'label': 'AVG(num)',
'optionName': 'metric_vgops097wej_g8uff99zhk7',
"expressionType": "SIMPLE",
"column": {"column_name": "num", "type": "BIGINT(20)"},
"aggregate": "AVG",
"label": "AVG(num)",
"optionName": "metric_vgops097wej_g8uff99zhk7",
},
metric_2='sum__num',
granularity_sqla='ds')),
metric_2="sum__num",
granularity_sqla="ds",
),
),
Slice(
slice_name='Title',
viz_type='markup',
datasource_type='table',
slice_name="Title",
viz_type="markup",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='markup', markup_type='html',
viz_type="markup",
markup_type="html",
code="""\
<div style='text-align:center'>
<h1>Birth Names Dashboard</h1>
@ -237,135 +237,156 @@ def load_birth_names():
</p>
<img src='/static/assets/images/babytux.jpg'>
</div>
""")),
""",
),
),
Slice(
slice_name='Name Cloud',
viz_type='word_cloud',
datasource_type='table',
slice_name="Name Cloud",
viz_type="word_cloud",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='word_cloud', size_from='10',
series='name', size_to='70', rotation='square',
limit='100')),
viz_type="word_cloud",
size_from="10",
series="name",
size_to="70",
rotation="square",
limit="100",
),
),
Slice(
slice_name='Pivot Table',
viz_type='pivot_table',
datasource_type='table',
slice_name="Pivot Table",
viz_type="pivot_table",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='pivot_table', metrics=['sum__num'],
groupby=['name'], columns=['state'])),
viz_type="pivot_table",
metrics=["sum__num"],
groupby=["name"],
columns=["state"],
),
),
Slice(
slice_name='Number of Girls',
viz_type='big_number_total',
datasource_type='table',
slice_name="Number of Girls",
viz_type="big_number_total",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='big_number_total', granularity_sqla='ds',
filters=[{
'col': 'gender',
'op': 'in',
'val': ['girl'],
}],
subheader='total female participants')),
viz_type="big_number_total",
granularity_sqla="ds",
filters=[{"col": "gender", "op": "in", "val": ["girl"]}],
subheader="total female participants",
),
),
Slice(
slice_name='Number of California Births',
viz_type='big_number_total',
datasource_type='table',
slice_name="Number of California Births",
viz_type="big_number_total",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
metric={
'expressionType': 'SIMPLE',
'column': {
'column_name': 'num_california',
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
"expressionType": "SIMPLE",
"column": {
"column_name": "num_california",
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
},
'aggregate': 'SUM',
'label': 'SUM(num_california)',
"aggregate": "SUM",
"label": "SUM(num_california)",
},
viz_type='big_number_total',
granularity_sqla='ds')),
viz_type="big_number_total",
granularity_sqla="ds",
),
),
Slice(
slice_name='Top 10 California Names Timeseries',
viz_type='line',
datasource_type='table',
slice_name="Top 10 California Names Timeseries",
viz_type="line",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
metrics=[{
'expressionType': 'SIMPLE',
'column': {
'column_name': 'num_california',
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
},
'aggregate': 'SUM',
'label': 'SUM(num_california)',
}],
viz_type='line',
granularity_sqla='ds',
groupby=['name'],
metrics=[
{
"expressionType": "SIMPLE",
"column": {
"column_name": "num_california",
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
},
"aggregate": "SUM",
"label": "SUM(num_california)",
}
],
viz_type="line",
granularity_sqla="ds",
groupby=["name"],
timeseries_limit_metric={
'expressionType': 'SIMPLE',
'column': {
'column_name': 'num_california',
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
"expressionType": "SIMPLE",
"column": {
"column_name": "num_california",
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
},
'aggregate': 'SUM',
'label': 'SUM(num_california)',
"aggregate": "SUM",
"label": "SUM(num_california)",
},
limit='10')),
limit="10",
),
),
Slice(
slice_name='Names Sorted by Num in California',
viz_type='table',
datasource_type='table',
slice_name="Names Sorted by Num in California",
viz_type="table",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
groupby=['name'],
groupby=["name"],
row_limit=50,
timeseries_limit_metric={
'expressionType': 'SIMPLE',
'column': {
'column_name': 'num_california',
'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
"expressionType": "SIMPLE",
"column": {
"column_name": "num_california",
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
},
'aggregate': 'SUM',
'label': 'SUM(num_california)',
})),
"aggregate": "SUM",
"label": "SUM(num_california)",
},
),
),
Slice(
slice_name='Num Births Trend',
viz_type='line',
datasource_type='table',
slice_name="Num Births Trend",
viz_type="line",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='line')),
params=get_slice_json(defaults, viz_type="line"),
),
Slice(
slice_name='Daily Totals',
viz_type='table',
datasource_type='table',
slice_name="Daily Totals",
viz_type="table",
datasource_type="table",
datasource_id=tbl.id,
created_by=admin,
params=get_slice_json(
defaults,
groupby=['ds'],
since='40 years ago',
until='now',
viz_type='table')),
groupby=["ds"],
since="40 years ago",
until="now",
viz_type="table",
),
),
]
for slc in slices:
merge_slice(slc)
print('Creating a dashboard')
dash = db.session.query(Dash).filter_by(dashboard_title='Births').first()
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(dashboard_title="Births").first()
if not dash:
dash = Dash()
js = textwrap.dedent("""\
js = textwrap.dedent(
# pylint: disable=line-too-long
"""\
{
"CHART-0dd270f0": {
"meta": {
@ -614,13 +635,15 @@ def load_birth_names():
},
"DASHBOARD_VERSION_KEY": "v2"
}
""")
"""
# pylint: enable=line-too-long
)
pos = json.loads(js)
# dashboard v2 doesn't allow add markup slice
dash.slices = [slc for slc in slices if slc.viz_type != 'markup']
dash.slices = [slc for slc in slices if slc.viz_type != "markup"]
update_slice_ids(pos, dash.slices)
dash.dashboard_title = 'Births'
dash.dashboard_title = "Births"
dash.position_json = json.dumps(pos, indent=4)
dash.slug = 'births'
dash.slug = "births"
db.session.merge(dash)
db.session.commit()

File diff suppressed because it is too large


@ -36,75 +36,71 @@ from .helpers import (
def load_country_map_data():
"""Loading data for map with country map"""
csv_bytes = get_example_data(
'birth_france_data_for_country_map.csv', is_gzip=False, make_bytes=True)
data = pd.read_csv(csv_bytes, encoding='utf-8')
data['dttm'] = datetime.datetime.now().date()
"birth_france_data_for_country_map.csv", is_gzip=False, make_bytes=True
)
data = pd.read_csv(csv_bytes, encoding="utf-8")
data["dttm"] = datetime.datetime.now().date()
data.to_sql( # pylint: disable=no-member
'birth_france_by_region',
"birth_france_by_region",
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'DEPT_ID': String(10),
'2003': BigInteger,
'2004': BigInteger,
'2005': BigInteger,
'2006': BigInteger,
'2007': BigInteger,
'2008': BigInteger,
'2009': BigInteger,
'2010': BigInteger,
'2011': BigInteger,
'2012': BigInteger,
'2013': BigInteger,
'2014': BigInteger,
'dttm': Date(),
"DEPT_ID": String(10),
"2003": BigInteger,
"2004": BigInteger,
"2005": BigInteger,
"2006": BigInteger,
"2007": BigInteger,
"2008": BigInteger,
"2009": BigInteger,
"2010": BigInteger,
"2011": BigInteger,
"2012": BigInteger,
"2013": BigInteger,
"2014": BigInteger,
"dttm": Date(),
},
index=False)
print('Done loading table!')
print('-' * 80)
print('Creating table reference')
obj = db.session.query(TBL).filter_by(table_name='birth_france_by_region').first()
index=False,
)
print("Done loading table!")
print("-" * 80)
print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name="birth_france_by_region").first()
if not obj:
obj = TBL(table_name='birth_france_by_region')
obj.main_dttm_col = 'dttm'
obj = TBL(table_name="birth_france_by_region")
obj.main_dttm_col = "dttm"
obj.database = utils.get_or_create_main_db()
if not any(col.metric_name == 'avg__2004' for col in obj.metrics):
col = str(column('2004').compile(db.engine))
obj.metrics.append(SqlMetric(
metric_name='avg__2004',
expression=f'AVG({col})',
))
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
slice_data = {
'granularity_sqla': '',
'since': '',
'until': '',
'where': '',
'viz_type': 'country_map',
'entity': 'DEPT_ID',
'metric': {
'expressionType': 'SIMPLE',
'column': {
'type': 'INT',
'column_name': '2004',
},
'aggregate': 'AVG',
'label': 'Boys',
'optionName': 'metric_112342',
"granularity_sqla": "",
"since": "",
"until": "",
"where": "",
"viz_type": "country_map",
"entity": "DEPT_ID",
"metric": {
"expressionType": "SIMPLE",
"column": {"type": "INT", "column_name": "2004"},
"aggregate": "AVG",
"label": "Boys",
"optionName": "metric_112342",
},
'row_limit': 500000,
"row_limit": 500000,
}
print('Creating a slice')
print("Creating a slice")
slc = Slice(
slice_name='Birth in France by department in 2016',
viz_type='country_map',
datasource_type='table',
slice_name="Birth in France by department in 2016",
viz_type="country_map",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)


@ -22,12 +22,13 @@ from superset.models.core import CssTemplate
def load_css_templates():
"""Loads 2 css templates to demonstrate the feature"""
print('Creating default CSS templates')
print("Creating default CSS templates")
obj = db.session.query(CssTemplate).filter_by(template_name='Flat').first()
obj = db.session.query(CssTemplate).filter_by(template_name="Flat").first()
if not obj:
obj = CssTemplate(template_name='Flat')
css = textwrap.dedent("""\
obj = CssTemplate(template_name="Flat")
css = textwrap.dedent(
"""\
.gridster div.widget {
transition: background-color 0.5s ease;
background-color: #FAFAFA;
@ -58,16 +59,17 @@ def load_css_templates():
'#ff3339', '#ff1ab1', '#005c66', '#00b3a5', '#55d12e', '#b37e00', '#988b4e',
];
*/
""")
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()
obj = (
db.session.query(CssTemplate).filter_by(template_name='Courier Black').first())
obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
if not obj:
obj = CssTemplate(template_name='Courier Black')
css = textwrap.dedent("""\
obj = CssTemplate(template_name="Courier Black")
css = textwrap.dedent(
"""\
.gridster div.widget {
transition: background-color 0.5s ease;
background-color: #EEE;
@ -113,7 +115,8 @@ def load_css_templates():
'#ff3339', '#ff1ab1', '#005c66', '#00b3a5', '#55d12e', '#b37e00', '#988b4e',
];
*/
""")
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()


@ -18,21 +18,9 @@
import json
from superset import db
from .helpers import (
Dash,
get_slice_json,
merge_slice,
Slice,
TBL,
update_slice_ids,
)
from .helpers import Dash, get_slice_json, merge_slice, Slice, TBL, update_slice_ids
COLOR_RED = {
'r': 205,
'g': 0,
'b': 3,
'a': 0.82,
}
COLOR_RED = {"r": 205, "g": 0, "b": 3, "a": 0.82}
POSITION_JSON = """\
{
"CHART-3afd9d70": {
@ -177,46 +165,42 @@ POSITION_JSON = """\
def load_deck_dash():
print('Loading deck.gl dashboard')
print("Loading deck.gl dashboard")
slices = []
tbl = db.session.query(TBL).filter_by(table_name='long_lat').first()
tbl = db.session.query(TBL).filter_by(table_name="long_lat").first()
slice_data = {
'spatial': {
'type': 'latlong',
'lonCol': 'LON',
'latCol': 'LAT',
"spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
"color_picker": COLOR_RED,
"datasource": "5__table",
"filters": [],
"granularity_sqla": None,
"groupby": [],
"having": "",
"mapbox_style": "mapbox://styles/mapbox/light-v9",
"multiplier": 10,
"point_radius_fixed": {"type": "metric", "value": "count"},
"point_unit": "square_m",
"min_radius": 1,
"row_limit": 5000,
"time_range": " : ",
"size": "count",
"time_grain_sqla": None,
"viewport": {
"bearing": -4.952916738791771,
"latitude": 37.78926922909199,
"longitude": -122.42613341901688,
"pitch": 4.750411100577438,
"zoom": 12.729132798697304,
},
'color_picker': COLOR_RED,
'datasource': '5__table',
'filters': [],
'granularity_sqla': None,
'groupby': [],
'having': '',
'mapbox_style': 'mapbox://styles/mapbox/light-v9',
'multiplier': 10,
'point_radius_fixed': {'type': 'metric', 'value': 'count'},
'point_unit': 'square_m',
'min_radius': 1,
'row_limit': 5000,
'time_range': ' : ',
'size': 'count',
'time_grain_sqla': None,
'viewport': {
'bearing': -4.952916738791771,
'latitude': 37.78926922909199,
'longitude': -122.42613341901688,
'pitch': 4.750411100577438,
'zoom': 12.729132798697304,
},
'viz_type': 'deck_scatter',
'where': '',
"viz_type": "deck_scatter",
"where": "",
}
print('Creating Scatterplot slice')
print("Creating Scatterplot slice")
slc = Slice(
slice_name='Scatterplot',
viz_type='deck_scatter',
datasource_type='table',
slice_name="Scatterplot",
viz_type="deck_scatter",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -224,46 +208,37 @@ def load_deck_dash():
slices.append(slc)
slice_data = {
'point_unit': 'square_m',
'filters': [],
'row_limit': 5000,
'spatial': {
'type': 'latlong',
'lonCol': 'LON',
'latCol': 'LAT',
"point_unit": "square_m",
"filters": [],
"row_limit": 5000,
"spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
"mapbox_style": "mapbox://styles/mapbox/dark-v9",
"granularity_sqla": None,
"size": "count",
"viz_type": "deck_screengrid",
"time_range": "No filter",
"point_radius": "Auto",
"color_picker": {"a": 1, "r": 14, "b": 0, "g": 255},
"grid_size": 20,
"where": "",
"having": "",
"viewport": {
"zoom": 14.161641703941438,
"longitude": -122.41827069521386,
"bearing": -4.952916738791771,
"latitude": 37.76024135844065,
"pitch": 4.750411100577438,
},
'mapbox_style': 'mapbox://styles/mapbox/dark-v9',
'granularity_sqla': None,
'size': 'count',
'viz_type': 'deck_screengrid',
'time_range': 'No filter',
'point_radius': 'Auto',
'color_picker': {
'a': 1,
'r': 14,
'b': 0,
'g': 255,
},
'grid_size': 20,
'where': '',
'having': '',
'viewport': {
'zoom': 14.161641703941438,
'longitude': -122.41827069521386,
'bearing': -4.952916738791771,
'latitude': 37.76024135844065,
'pitch': 4.750411100577438,
},
'point_radius_fixed': {'type': 'fix', 'value': 2000},
'datasource': '5__table',
'time_grain_sqla': None,
'groupby': [],
"point_radius_fixed": {"type": "fix", "value": 2000},
"datasource": "5__table",
"time_grain_sqla": None,
"groupby": [],
}
print('Creating Screen Grid slice')
print("Creating Screen Grid slice")
slc = Slice(
slice_name='Screen grid',
viz_type='deck_screengrid',
datasource_type='table',
slice_name="Screen grid",
viz_type="deck_screengrid",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -271,47 +246,38 @@ def load_deck_dash():
slices.append(slc)
slice_data = {
'spatial': {
'type': 'latlong',
'lonCol': 'LON',
'latCol': 'LAT',
"spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
"filters": [],
"row_limit": 5000,
"mapbox_style": "mapbox://styles/mapbox/streets-v9",
"granularity_sqla": None,
"size": "count",
"viz_type": "deck_hex",
"time_range": "No filter",
"point_radius_unit": "Pixels",
"point_radius": "Auto",
"color_picker": {"a": 1, "r": 14, "b": 0, "g": 255},
"grid_size": 40,
"extruded": True,
"having": "",
"viewport": {
"latitude": 37.789795085160335,
"pitch": 54.08961642447763,
"zoom": 13.835465702403654,
"longitude": -122.40632230075536,
"bearing": -2.3984797349335167,
},
'filters': [],
'row_limit': 5000,
'mapbox_style': 'mapbox://styles/mapbox/streets-v9',
'granularity_sqla': None,
'size': 'count',
'viz_type': 'deck_hex',
'time_range': 'No filter',
'point_radius_unit': 'Pixels',
'point_radius': 'Auto',
'color_picker': {
'a': 1,
'r': 14,
'b': 0,
'g': 255,
},
'grid_size': 40,
'extruded': True,
'having': '',
'viewport': {
'latitude': 37.789795085160335,
'pitch': 54.08961642447763,
'zoom': 13.835465702403654,
'longitude': -122.40632230075536,
'bearing': -2.3984797349335167,
},
'where': '',
'point_radius_fixed': {'type': 'fix', 'value': 2000},
'datasource': '5__table',
'time_grain_sqla': None,
'groupby': [],
"where": "",
"point_radius_fixed": {"type": "fix", "value": 2000},
"datasource": "5__table",
"time_grain_sqla": None,
"groupby": [],
}
print('Creating Hex slice')
print("Creating Hex slice")
slc = Slice(
slice_name='Hexagons',
viz_type='deck_hex',
datasource_type='table',
slice_name="Hexagons",
viz_type="deck_hex",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
@ -319,120 +285,98 @@ def load_deck_dash():
slices.append(slc)
slice_data = {
'spatial': {
'type': 'latlong',
'lonCol': 'LON',
'latCol': 'LAT',
"spatial": {"type": "latlong", "lonCol": "LON", "latCol": "LAT"},
"filters": [],
"row_limit": 5000,
"mapbox_style": "mapbox://styles/mapbox/satellite-streets-v9",
"granularity_sqla": None,
"size": "count",
"viz_type": "deck_grid",
"point_radius_unit": "Pixels",
"point_radius": "Auto",
"time_range": "No filter",
"color_picker": {"a": 1, "r": 14, "b": 0, "g": 255},
"grid_size": 120,
"extruded": True,
"having": "",
"viewport": {
"longitude": -122.42066918995666,
"bearing": 155.80099696026355,
"zoom": 12.699690845482069,
"latitude": 37.7942314882596,
"pitch": 53.470800300695146,
},
'filters': [],
'row_limit': 5000,
'mapbox_style': 'mapbox://styles/mapbox/satellite-streets-v9',
'granularity_sqla': None,
'size': 'count',
'viz_type': 'deck_grid',
'point_radius_unit': 'Pixels',
'point_radius': 'Auto',
'time_range': 'No filter',
'color_picker': {
'a': 1,
'r': 14,
'b': 0,
'g': 255,
},
'grid_size': 120,
'extruded': True,
'having': '',
'viewport': {
'longitude': -122.42066918995666,
'bearing': 155.80099696026355,
'zoom': 12.699690845482069,
'latitude': 37.7942314882596,
'pitch': 53.470800300695146,
},
'where': '',
'point_radius_fixed': {'type': 'fix', 'value': 2000},
'datasource': '5__table',
'time_grain_sqla': None,
'groupby': [],
"where": "",
"point_radius_fixed": {"type": "fix", "value": 2000},
"datasource": "5__table",
"time_grain_sqla": None,
"groupby": [],
}
print('Creating Grid slice')
print("Creating Grid slice")
slc = Slice(
slice_name='Grid',
viz_type='deck_grid',
datasource_type='table',
slice_name="Grid",
viz_type="deck_grid",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
slices.append(slc)
polygon_tbl = db.session.query(TBL) \
.filter_by(table_name='sf_population_polygons').first()
polygon_tbl = (
db.session.query(TBL).filter_by(table_name="sf_population_polygons").first()
)
slice_data = {
'datasource': '11__table',
'viz_type': 'deck_polygon',
'slice_id': 41,
'granularity_sqla': None,
'time_grain_sqla': None,
'time_range': ' : ',
'line_column': 'contour',
'metric': None,
'line_type': 'json',
'mapbox_style': 'mapbox://styles/mapbox/light-v9',
'viewport': {
'longitude': -122.43388541747726,
'latitude': 37.752020331384834,
'zoom': 11.133995608594631,
'bearing': 37.89506450385642,
'pitch': 60,
'width': 667,
'height': 906,
'altitude': 1.5,
'maxZoom': 20,
'minZoom': 0,
'maxPitch': 60,
'minPitch': 0,
'maxLatitude': 85.05113,
'minLatitude': -85.05113,
"datasource": "11__table",
"viz_type": "deck_polygon",
"slice_id": 41,
"granularity_sqla": None,
"time_grain_sqla": None,
"time_range": " : ",
"line_column": "contour",
"metric": None,
"line_type": "json",
"mapbox_style": "mapbox://styles/mapbox/light-v9",
"viewport": {
"longitude": -122.43388541747726,
"latitude": 37.752020331384834,
"zoom": 11.133995608594631,
"bearing": 37.89506450385642,
"pitch": 60,
"width": 667,
"height": 906,
"altitude": 1.5,
"maxZoom": 20,
"minZoom": 0,
"maxPitch": 60,
"minPitch": 0,
"maxLatitude": 85.05113,
"minLatitude": -85.05113,
},
'reverse_long_lat': False,
'fill_color_picker': {
'r': 3,
'g': 65,
'b': 73,
'a': 1,
},
'stroke_color_picker': {
'r': 0,
'g': 122,
'b': 135,
'a': 1,
},
'filled': True,
'stroked': False,
'extruded': True,
'point_radius_scale': 100,
'js_columns': [
'population',
'area',
],
'js_data_mutator':
'data => data.map(d => ({\n'
' ...d,\n'
' elevation: d.extraProps.population/d.extraProps.area/10,\n'
'}));',
'js_tooltip': '',
'js_onclick_href': '',
'where': '',
'having': '',
'filters': [],
"reverse_long_lat": False,
"fill_color_picker": {"r": 3, "g": 65, "b": 73, "a": 1},
"stroke_color_picker": {"r": 0, "g": 122, "b": 135, "a": 1},
"filled": True,
"stroked": False,
"extruded": True,
"point_radius_scale": 100,
"js_columns": ["population", "area"],
"js_data_mutator": "data => data.map(d => ({\n"
" ...d,\n"
" elevation: d.extraProps.population/d.extraProps.area/10,\n"
"}));",
"js_tooltip": "",
"js_onclick_href": "",
"where": "",
"having": "",
"filters": [],
}
print('Creating Polygon slice')
print("Creating Polygon slice")
slc = Slice(
slice_name='Polygons',
viz_type='deck_polygon',
datasource_type='table',
slice_name="Polygons",
viz_type="deck_polygon",
datasource_type="table",
datasource_id=polygon_tbl.id,
params=get_slice_json(slice_data),
)
@ -440,125 +384,116 @@ def load_deck_dash():
slices.append(slc)
slice_data = {
'datasource': '10__table',
'viz_type': 'deck_arc',
'slice_id': 42,
'granularity_sqla': None,
'time_grain_sqla': None,
'time_range': ' : ',
'start_spatial': {
'type': 'latlong',
'latCol': 'LATITUDE',
'lonCol': 'LONGITUDE',
"datasource": "10__table",
"viz_type": "deck_arc",
"slice_id": 42,
"granularity_sqla": None,
"time_grain_sqla": None,
"time_range": " : ",
"start_spatial": {
"type": "latlong",
"latCol": "LATITUDE",
"lonCol": "LONGITUDE",
},
'end_spatial': {
'type': 'latlong',
'latCol': 'LATITUDE_DEST',
'lonCol': 'LONGITUDE_DEST',
"end_spatial": {
"type": "latlong",
"latCol": "LATITUDE_DEST",
"lonCol": "LONGITUDE_DEST",
},
'row_limit': 5000,
'mapbox_style': 'mapbox://styles/mapbox/light-v9',
'viewport': {
'altitude': 1.5,
'bearing': 8.546256357301871,
'height': 642,
'latitude': 44.596651438714254,
'longitude': -91.84340711201104,
'maxLatitude': 85.05113,
'maxPitch': 60,
'maxZoom': 20,
'minLatitude': -85.05113,
'minPitch': 0,
'minZoom': 0,
'pitch': 60,
'width': 997,
'zoom': 2.929837070560775,
"row_limit": 5000,
"mapbox_style": "mapbox://styles/mapbox/light-v9",
"viewport": {
"altitude": 1.5,
"bearing": 8.546256357301871,
"height": 642,
"latitude": 44.596651438714254,
"longitude": -91.84340711201104,
"maxLatitude": 85.05113,
"maxPitch": 60,
"maxZoom": 20,
"minLatitude": -85.05113,
"minPitch": 0,
"minZoom": 0,
"pitch": 60,
"width": 997,
"zoom": 2.929837070560775,
},
'color_picker': {
'r': 0,
'g': 122,
'b': 135,
'a': 1,
},
'stroke_width': 1,
'where': '',
'having': '',
'filters': [],
"color_picker": {"r": 0, "g": 122, "b": 135, "a": 1},
"stroke_width": 1,
"where": "",
"having": "",
"filters": [],
}
print('Creating Arc slice')
print("Creating Arc slice")
slc = Slice(
slice_name='Arcs',
viz_type='deck_arc',
datasource_type='table',
datasource_id=db.session.query(TBL).filter_by(table_name='flights').first().id,
slice_name="Arcs",
viz_type="deck_arc",
datasource_type="table",
datasource_id=db.session.query(TBL).filter_by(table_name="flights").first().id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
slices.append(slc)
slice_data = {
'datasource': '12__table',
'slice_id': 43,
'viz_type': 'deck_path',
'time_grain_sqla': None,
'time_range': ' : ',
'line_column': 'path_json',
'line_type': 'json',
'row_limit': 5000,
'mapbox_style': 'mapbox://styles/mapbox/light-v9',
'viewport': {
'longitude': -122.18885402582598,
'latitude': 37.73671752604488,
'zoom': 9.51847667620428,
'bearing': 0,
'pitch': 0,
'width': 669,
'height': 1094,
'altitude': 1.5,
'maxZoom': 20,
'minZoom': 0,
'maxPitch': 60,
'minPitch': 0,
'maxLatitude': 85.05113,
'minLatitude': -85.05113,
"datasource": "12__table",
"slice_id": 43,
"viz_type": "deck_path",
"time_grain_sqla": None,
"time_range": " : ",
"line_column": "path_json",
"line_type": "json",
"row_limit": 5000,
"mapbox_style": "mapbox://styles/mapbox/light-v9",
"viewport": {
"longitude": -122.18885402582598,
"latitude": 37.73671752604488,
"zoom": 9.51847667620428,
"bearing": 0,
"pitch": 0,
"width": 669,
"height": 1094,
"altitude": 1.5,
"maxZoom": 20,
"minZoom": 0,
"maxPitch": 60,
"minPitch": 0,
"maxLatitude": 85.05113,
"minLatitude": -85.05113,
},
'color_picker': {
'r': 0,
'g': 122,
'b': 135,
'a': 1,
},
'line_width': 150,
'reverse_long_lat': False,
'js_columns': [
'color',
],
'js_data_mutator': 'data => data.map(d => ({\n'
' ...d,\n'
' color: colors.hexToRGB(d.extraProps.color)\n'
'}));',
'js_tooltip': '',
'js_onclick_href': '',
'where': '',
'having': '',
'filters': [],
"color_picker": {"r": 0, "g": 122, "b": 135, "a": 1},
"line_width": 150,
"reverse_long_lat": False,
"js_columns": ["color"],
"js_data_mutator": "data => data.map(d => ({\n"
" ...d,\n"
" color: colors.hexToRGB(d.extraProps.color)\n"
"}));",
"js_tooltip": "",
"js_onclick_href": "",
"where": "",
"having": "",
"filters": [],
}
print('Creating Path slice')
print("Creating Path slice")
slc = Slice(
slice_name='Path',
viz_type='deck_path',
datasource_type='table',
datasource_id=db.session.query(TBL).filter_by(table_name='bart_lines').first().id,
slice_name="Path",
viz_type="deck_path",
datasource_type="table",
datasource_id=db.session.query(TBL)
.filter_by(table_name="bart_lines")
.first()
.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
slices.append(slc)
slug = 'deck'
slug = "deck"
print('Creating a dashboard')
title = 'deck.gl Demo'
print("Creating a dashboard")
title = "deck.gl Demo"
dash = db.session.query(Dash).filter_by(slug=slug).first()
if not dash:
@ -574,5 +509,5 @@ def load_deck_dash():
db.session.commit()
if __name__ == '__main__':
if __name__ == "__main__":
load_deck_dash()


@ -26,51 +26,53 @@ from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.utils import core as utils
from .helpers import (
DATA_FOLDER, get_example_data, merge_slice, misc_dash_slices, Slice, TBL,
DATA_FOLDER,
get_example_data,
merge_slice,
misc_dash_slices,
Slice,
TBL,
)
def load_energy():
"""Loads an energy related dataset to use with sankey and graphs"""
tbl_name = 'energy_usage'
data = get_example_data('energy.json.gz')
tbl_name = "energy_usage"
data = get_example_data("energy.json.gz")
pdf = pd.read_json(data)
pdf.to_sql(
tbl_name,
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'source': String(255),
'target': String(255),
'value': Float(),
},
index=False)
dtype={"source": String(255), "target": String(255), "value": Float()},
index=False,
)
print('Creating table [wb_health_population] reference')
print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = 'Energy consumption'
tbl.description = "Energy consumption"
tbl.database = utils.get_or_create_main_db()
if not any(col.metric_name == 'sum__value' for col in tbl.metrics):
col = str(column('value').compile(db.engine))
tbl.metrics.append(SqlMetric(
metric_name='sum__value',
expression=f'SUM({col})',
))
if not any(col.metric_name == "sum__value" for col in tbl.metrics):
col = str(column("value").compile(db.engine))
tbl.metrics.append(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
slc = Slice(
slice_name='Energy Sankey',
viz_type='sankey',
datasource_type='table',
slice_name="Energy Sankey",
viz_type="sankey",
datasource_type="table",
datasource_id=tbl.id,
params=textwrap.dedent("""\
params=textwrap.dedent(
"""\
{
"collapsed_fieldsets": "",
"groupby": [
@ -84,17 +86,19 @@ def load_energy():
"viz_type": "sankey",
"where": ""
}
"""),
"""
),
)
misc_dash_slices.add(slc.slice_name)
merge_slice(slc)
slc = Slice(
slice_name='Energy Force Layout',
viz_type='directed_force',
datasource_type='table',
slice_name="Energy Force Layout",
viz_type="directed_force",
datasource_type="table",
datasource_id=tbl.id,
params=textwrap.dedent("""\
params=textwrap.dedent(
"""\
{
"charge": "-500",
"collapsed_fieldsets": "",
@ -110,17 +114,19 @@ def load_energy():
"viz_type": "directed_force",
"where": ""
}
"""),
"""
),
)
misc_dash_slices.add(slc.slice_name)
merge_slice(slc)
slc = Slice(
slice_name='Heatmap',
viz_type='heatmap',
datasource_type='table',
slice_name="Heatmap",
viz_type="heatmap",
datasource_type="table",
datasource_id=tbl.id,
params=textwrap.dedent("""\
params=textwrap.dedent(
"""\
{
"all_columns_x": "source",
"all_columns_y": "target",
@ -136,7 +142,8 @@ def load_energy():
"xscale_interval": "1",
"yscale_interval": "1"
}
"""),
"""
),
)
misc_dash_slices.add(slc.slice_name)
merge_slice(slc)


@ -24,38 +24,37 @@ from .helpers import get_example_data, TBL
def load_flights():
"""Loading random time series data from a zip file in the repo"""
tbl_name = 'flights'
data = get_example_data('flight_data.csv.gz', make_bytes=True)
pdf = pd.read_csv(data, encoding='latin-1')
tbl_name = "flights"
data = get_example_data("flight_data.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding="latin-1")
# Loading airports info to join and get lat/long
airports_bytes = get_example_data('airports.csv.gz', make_bytes=True)
airports = pd.read_csv(airports_bytes, encoding='latin-1')
airports = airports.set_index('IATA_CODE')
airports_bytes = get_example_data("airports.csv.gz", make_bytes=True)
airports = pd.read_csv(airports_bytes, encoding="latin-1")
airports = airports.set_index("IATA_CODE")
pdf['ds'] = pdf.YEAR.map(str) + '-0' + pdf.MONTH.map(str) + '-0' + pdf.DAY.map(str)
pdf["ds"] = pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
pdf.ds = pd.to_datetime(pdf.ds)
del pdf['YEAR']
del pdf['MONTH']
del pdf['DAY']
del pdf["YEAR"]
del pdf["MONTH"]
del pdf["DAY"]
pdf = pdf.join(airports, on='ORIGIN_AIRPORT', rsuffix='_ORIG')
pdf = pdf.join(airports, on='DESTINATION_AIRPORT', rsuffix='_DEST')
pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql(
tbl_name,
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'ds': DateTime,
},
index=False)
dtype={"ds": DateTime},
index=False,
)
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = 'Random set of flights in the US'
tbl.description = "Random set of flights in the US"
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
print('Done loading table!')
print("Done loading table!")


@ -27,31 +27,32 @@ from superset import app, db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models import core as models
BASE_URL = 'https://github.com/apache-superset/examples-data/blob/master/'
BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"
# Shortcuts
DB = models.Database
Slice = models.Slice
Dash = models.Dashboard
TBL = ConnectorRegistry.sources['table']
TBL = ConnectorRegistry.sources["table"]
config = app.config
DATA_FOLDER = os.path.join(config.get('BASE_DIR'), 'data')
DATA_FOLDER = os.path.join(config.get("BASE_DIR"), "data")
misc_dash_slices = set() # slices assembled in a 'Misc Chart' dashboard
def update_slice_ids(layout_dict, slices):
charts = [
component for component in layout_dict.values()
if isinstance(component, dict) and component['type'] == 'CHART'
component
for component in layout_dict.values()
if isinstance(component, dict) and component["type"] == "CHART"
]
sorted_charts = sorted(charts, key=lambda k: k['meta']['chartId'])
sorted_charts = sorted(charts, key=lambda k: k["meta"]["chartId"])
for i, chart_component in enumerate(sorted_charts):
if i < len(slices):
chart_component['meta']['chartId'] = int(slices[i].id)
chart_component["meta"]["chartId"] = int(slices[i].id)
def merge_slice(slc):
@ -69,9 +70,9 @@ def get_slice_json(defaults, **kwargs):
def get_example_data(filepath, is_gzip=True, make_bytes=False):
content = requests.get(f'{BASE_URL}{filepath}?raw=true').content
content = requests.get(f"{BASE_URL}{filepath}?raw=true").content
if is_gzip:
content = zlib.decompress(content, zlib.MAX_WBITS|16)
content = zlib.decompress(content, zlib.MAX_WBITS | 16)
if make_bytes:
content = BytesIO(content)
return content


@ -35,50 +35,49 @@ from .helpers import (
def load_long_lat_data():
"""Loading lat/long data from a csv file in the repo"""
data = get_example_data('san_francisco.csv.gz', make_bytes=True)
pdf = pd.read_csv(data, encoding='utf-8')
start = datetime.datetime.now().replace(
hour=0, minute=0, second=0, microsecond=0)
pdf['datetime'] = [
data = get_example_data("san_francisco.csv.gz", make_bytes=True)
pdf = pd.read_csv(data, encoding="utf-8")
start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
pdf["datetime"] = [
start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
for i in range(len(pdf))
]
pdf['occupancy'] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf['radius_miles'] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf['geohash'] = pdf[['LAT', 'LON']].apply(
lambda x: geohash.encode(*x), axis=1)
pdf['delimited'] = pdf['LAT'].map(str).str.cat(pdf['LON'].map(str), sep=',')
pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf.to_sql( # pylint: disable=no-member
'long_lat',
"long_lat",
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'longitude': Float(),
'latitude': Float(),
'number': Float(),
'street': String(100),
'unit': String(10),
'city': String(50),
'district': String(50),
'region': String(50),
'postcode': Float(),
'id': String(100),
'datetime': DateTime(),
'occupancy': Float(),
'radius_miles': Float(),
'geohash': String(12),
'delimited': String(60),
"longitude": Float(),
"latitude": Float(),
"number": Float(),
"street": String(100),
"unit": String(10),
"city": String(50),
"district": String(50),
"region": String(50),
"postcode": Float(),
"id": String(100),
"datetime": DateTime(),
"occupancy": Float(),
"radius_miles": Float(),
"geohash": String(12),
"delimited": String(60),
},
index=False)
print('Done loading table!')
print('-' * 80)
index=False,
)
print("Done loading table!")
print("-" * 80)
print('Creating table reference')
obj = db.session.query(TBL).filter_by(table_name='long_lat').first()
print("Creating table reference")
obj = db.session.query(TBL).filter_by(table_name="long_lat").first()
if not obj:
obj = TBL(table_name='long_lat')
obj.main_dttm_col = 'datetime'
obj = TBL(table_name="long_lat")
obj.main_dttm_col = "datetime"
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
@ -86,23 +85,23 @@ def load_long_lat_data():
tbl = obj
slice_data = {
'granularity_sqla': 'day',
'since': '2014-01-01',
'until': 'now',
'where': '',
'viz_type': 'mapbox',
'all_columns_x': 'LON',
'all_columns_y': 'LAT',
'mapbox_style': 'mapbox://styles/mapbox/light-v9',
'all_columns': ['occupancy'],
'row_limit': 500000,
"granularity_sqla": "day",
"since": "2014-01-01",
"until": "now",
"where": "",
"viz_type": "mapbox",
"all_columns_x": "LON",
"all_columns_y": "LAT",
"mapbox_style": "mapbox://styles/mapbox/light-v9",
"all_columns": ["occupancy"],
"row_limit": 500000,
}
print('Creating a slice')
print("Creating a slice")
slc = Slice(
slice_name='Mapbox Long/Lat',
viz_type='mapbox',
datasource_type='table',
slice_name="Mapbox Long/Lat",
viz_type="mapbox",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)


@ -19,26 +19,22 @@ import textwrap
from superset import db
from .helpers import (
Dash,
misc_dash_slices,
Slice,
update_slice_ids,
)
from .helpers import Dash, misc_dash_slices, Slice, update_slice_ids
DASH_SLUG = 'misc_charts'
DASH_SLUG = "misc_charts"
def load_misc_dashboard():
"""Loading a dashboard featuring misc charts"""
print('Creating the dashboard')
print("Creating the dashboard")
db.session.expunge_all()
dash = db.session.query(Dash).filter_by(slug=DASH_SLUG).first()
if not dash:
dash = Dash()
js = textwrap.dedent("""\
js = textwrap.dedent(
"""\
{
"CHART-BkeVbh8ANQ": {
"children": [],
@ -210,17 +206,15 @@ def load_misc_dashboard():
},
"DASHBOARD_VERSION_KEY": "v2"
}
""")
"""
)
pos = json.loads(js)
slices = (
db.session
.query(Slice)
.filter(Slice.slice_name.in_(misc_dash_slices))
.all()
db.session.query(Slice).filter(Slice.slice_name.in_(misc_dash_slices)).all()
)
slices = sorted(slices, key=lambda x: x.id)
update_slice_ids(pos, slices)
dash.dashboard_title = 'Misc Charts'
dash.dashboard_title = "Misc Charts"
dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG
dash.slices = slices


@ -18,11 +18,7 @@ import json
from superset import db
from .birth_names import load_birth_names
from .helpers import (
merge_slice,
misc_dash_slices,
Slice,
)
from .helpers import merge_slice, misc_dash_slices, Slice
from .world_bank import load_world_bank_health_n_pop
@ -30,27 +26,30 @@ def load_multi_line():
load_world_bank_health_n_pop()
load_birth_names()
ids = [
row.id for row in
db.session.query(Slice).filter(
Slice.slice_name.in_(['Growth Rate', 'Trends']))
row.id
for row in db.session.query(Slice).filter(
Slice.slice_name.in_(["Growth Rate", "Trends"])
)
]
slc = Slice(
datasource_type='table', # not true, but needed
datasource_id=1, # cannot be empty
slice_name='Multi Line',
viz_type='line_multi',
params=json.dumps({
'slice_name': 'Multi Line',
'viz_type': 'line_multi',
'line_charts': [ids[0]],
'line_charts_2': [ids[1]],
'since': '1970',
'until': '1995',
'prefix_metric_with_slice_name': True,
'show_legend': False,
'x_axis_format': '%Y',
}),
datasource_type="table", # not true, but needed
datasource_id=1, # cannot be empty
slice_name="Multi Line",
viz_type="line_multi",
params=json.dumps(
{
"slice_name": "Multi Line",
"viz_type": "line_multi",
"line_charts": [ids[0]],
"line_charts_2": [ids[1]],
"since": "1970",
"until": "1995",
"prefix_metric_with_slice_name": True,
"show_legend": False,
"x_axis_format": "%Y",
}
),
)
misc_dash_slices.add(slc.slice_name)


@ -33,44 +33,45 @@ from .helpers import (
def load_multiformat_time_series():
"""Loading time series data from a zip file in the repo"""
data = get_example_data('multiformat_time_series.json.gz')
data = get_example_data("multiformat_time_series.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit='s')
pdf.ds2 = pd.to_datetime(pdf.ds2, unit='s')
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
pdf.to_sql(
'multiformat_time_series',
"multiformat_time_series",
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'ds': Date,
'ds2': DateTime,
'epoch_s': BigInteger,
'epoch_ms': BigInteger,
'string0': String(100),
'string1': String(100),
'string2': String(100),
'string3': String(100),
"ds": Date,
"ds2": DateTime,
"epoch_s": BigInteger,
"epoch_ms": BigInteger,
"string0": String(100),
"string1": String(100),
"string2": String(100),
"string3": String(100),
},
index=False)
print('Done loading table!')
print('-' * 80)
print('Creating table [multiformat_time_series] reference')
obj = db.session.query(TBL).filter_by(table_name='multiformat_time_series').first()
index=False,
)
print("Done loading table!")
print("-" * 80)
print("Creating table [multiformat_time_series] reference")
obj = db.session.query(TBL).filter_by(table_name="multiformat_time_series").first()
if not obj:
obj = TBL(table_name='multiformat_time_series')
obj.main_dttm_col = 'ds'
obj = TBL(table_name="multiformat_time_series")
obj.main_dttm_col = "ds"
obj.database = utils.get_or_create_main_db()
dttm_and_expr_dict = {
'ds': [None, None],
'ds2': [None, None],
'epoch_s': ['epoch_s', None],
'epoch_ms': ['epoch_ms', None],
'string2': ['%Y%m%d-%H%M%S', None],
'string1': ['%Y-%m-%d^%H:%M:%S', None],
'string0': ['%Y-%m-%d %H:%M:%S.%f', None],
'string3': ['%Y/%m/%d%H:%M:%S.%f', None],
"ds": [None, None],
"ds2": [None, None],
"epoch_s": ["epoch_s", None],
"epoch_ms": ["epoch_ms", None],
"string2": ["%Y%m%d-%H%M%S", None],
"string1": ["%Y-%m-%d^%H:%M:%S", None],
"string0": ["%Y-%m-%d %H:%M:%S.%f", None],
"string3": ["%Y/%m/%d%H:%M:%S.%f", None],
}
for col in obj.columns:
dttm_and_expr = dttm_and_expr_dict[col.column_name]
@ -82,26 +83,26 @@ def load_multiformat_time_series():
obj.fetch_metadata()
tbl = obj
print('Creating Heatmap charts')
print("Creating Heatmap charts")
for i, col in enumerate(tbl.columns):
slice_data = {
'metrics': ['count'],
'granularity_sqla': col.column_name,
'row_limit': config.get('ROW_LIMIT'),
'since': '2015',
'until': '2016',
'where': '',
'viz_type': 'cal_heatmap',
'domain_granularity': 'month',
'subdomain_granularity': 'day',
"metrics": ["count"],
"granularity_sqla": col.column_name,
"row_limit": config.get("ROW_LIMIT"),
"since": "2015",
"until": "2016",
"where": "",
"viz_type": "cal_heatmap",
"domain_granularity": "month",
"subdomain_granularity": "day",
}
slc = Slice(
slice_name=f'Calendar Heatmap multiformat {i}',
viz_type='cal_heatmap',
datasource_type='table',
slice_name=f"Calendar Heatmap multiformat {i}",
viz_type="cal_heatmap",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
misc_dash_slices.add('Calendar Heatmap multiformat 0')
misc_dash_slices.add("Calendar Heatmap multiformat 0")

View File

@ -25,29 +25,30 @@ from .helpers import TBL, get_example_data
def load_paris_iris_geojson():
tbl_name = 'paris_iris_mapping'
tbl_name = "paris_iris_mapping"
data = get_example_data('paris_iris.json.gz')
data = get_example_data("paris_iris.json.gz")
df = pd.read_json(data)
df['features'] = df.features.map(json.dumps)
df["features"] = df.features.map(json.dumps)
df.to_sql(
tbl_name,
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'color': String(255),
'name': String(255),
'features': Text,
'type': Text,
"color": String(255),
"name": String(255),
"features": Text,
"type": Text,
},
index=False)
print('Creating table {} reference'.format(tbl_name))
index=False,
)
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = 'Map of Paris'
tbl.description = "Map of Paris"
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()

View File

@ -20,38 +20,30 @@ from sqlalchemy import DateTime
from superset import db
from superset.utils import core as utils
from .helpers import (
config,
get_example_data,
get_slice_json,
merge_slice,
Slice,
TBL,
)
from .helpers import config, get_example_data, get_slice_json, merge_slice, Slice, TBL
def load_random_time_series_data():
"""Loading random time series data from a zip file in the repo"""
data = get_example_data('random_time_series.json.gz')
data = get_example_data("random_time_series.json.gz")
pdf = pd.read_json(data)
pdf.ds = pd.to_datetime(pdf.ds, unit='s')
pdf.ds = pd.to_datetime(pdf.ds, unit="s")
pdf.to_sql(
'random_time_series',
"random_time_series",
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'ds': DateTime,
},
index=False)
print('Done loading table!')
print('-' * 80)
dtype={"ds": DateTime},
index=False,
)
print("Done loading table!")
print("-" * 80)
print('Creating table [random_time_series] reference')
obj = db.session.query(TBL).filter_by(table_name='random_time_series').first()
print("Creating table [random_time_series] reference")
obj = db.session.query(TBL).filter_by(table_name="random_time_series").first()
if not obj:
obj = TBL(table_name='random_time_series')
obj.main_dttm_col = 'ds'
obj = TBL(table_name="random_time_series")
obj.main_dttm_col = "ds"
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
@ -59,22 +51,22 @@ def load_random_time_series_data():
tbl = obj
slice_data = {
'granularity_sqla': 'day',
'row_limit': config.get('ROW_LIMIT'),
'since': '1 year ago',
'until': 'now',
'metric': 'count',
'where': '',
'viz_type': 'cal_heatmap',
'domain_granularity': 'month',
'subdomain_granularity': 'day',
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
"since": "1 year ago",
"until": "now",
"metric": "count",
"where": "",
"viz_type": "cal_heatmap",
"domain_granularity": "month",
"subdomain_granularity": "day",
}
print('Creating a slice')
print("Creating a slice")
slc = Slice(
slice_name='Calendar Heatmap',
viz_type='cal_heatmap',
datasource_type='table',
slice_name="Calendar Heatmap",
viz_type="cal_heatmap",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)

View File

@ -25,29 +25,30 @@ from .helpers import TBL, get_example_data
def load_sf_population_polygons():
tbl_name = 'sf_population_polygons'
tbl_name = "sf_population_polygons"
data = get_example_data('sf_population.json.gz')
data = get_example_data("sf_population.json.gz")
df = pd.read_json(data)
df['contour'] = df.contour.map(json.dumps)
df["contour"] = df.contour.map(json.dumps)
df.to_sql(
tbl_name,
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'zipcode': BigInteger,
'population': BigInteger,
'contour': Text,
'area': BigInteger,
"zipcode": BigInteger,
"population": BigInteger,
"contour": Text,
"area": BigInteger,
},
index=False)
print('Creating table {} reference'.format(tbl_name))
index=False,
)
print("Creating table {} reference".format(tbl_name))
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = 'Population density of San Francisco'
tbl.description = "Population density of San Francisco"
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()

View File

@ -44,7 +44,7 @@ def load_tabbed_dashboard():
"""Creating a tabbed dashboard"""
print("Creating a dashboard with nested tabs")
slug = 'tabbed_dash'
slug = "tabbed_dash"
dash = db.session.query(Dash).filter_by(slug=slug).first()
if not dash:
@ -53,12 +53,13 @@ def load_tabbed_dashboard():
# reuse charts in "World's Bank Data and create
# new dashboard with nested tabs
tabbed_dash_slices = set()
tabbed_dash_slices.add('Region Filter')
tabbed_dash_slices.add('Growth Rate')
tabbed_dash_slices.add('Treemap')
tabbed_dash_slices.add('Box plot')
tabbed_dash_slices.add("Region Filter")
tabbed_dash_slices.add("Growth Rate")
tabbed_dash_slices.add("Treemap")
tabbed_dash_slices.add("Box plot")
js = textwrap.dedent("""\
js = textwrap.dedent(
"""\
{
"CHART-c0EjR-OZ0n": {
"children": [],
@ -337,12 +338,11 @@ def load_tabbed_dashboard():
"type": "TABS"
}
}
""")
"""
)
pos = json.loads(js)
slices = [
db.session.query(Slice)
.filter_by(slice_name=name)
.first()
db.session.query(Slice).filter_by(slice_name=name).first()
for name in tabbed_dash_slices
]
@ -350,7 +350,7 @@ def load_tabbed_dashboard():
update_slice_ids(pos, slices)
dash.position_json = json.dumps(pos, indent=4)
dash.slices = slices
dash.dashboard_title = 'Tabbed Dashboard'
dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug
db.session.merge(dash)

View File

@ -38,32 +38,34 @@ from .helpers import (
def load_unicode_test_data():
"""Loading unicode test dataset from a csv file in the repo"""
data = get_example_data(
'unicode_utf8_unixnl_test.csv', is_gzip=False, make_bytes=True)
df = pd.read_csv(data, encoding='utf-8')
"unicode_utf8_unixnl_test.csv", is_gzip=False, make_bytes=True
)
df = pd.read_csv(data, encoding="utf-8")
# generate date/numeric data
df['dttm'] = datetime.datetime.now().date()
df['value'] = [random.randint(1, 100) for _ in range(len(df))]
df["dttm"] = datetime.datetime.now().date()
df["value"] = [random.randint(1, 100) for _ in range(len(df))]
df.to_sql( # pylint: disable=no-member
'unicode_test',
"unicode_test",
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=500,
dtype={
'phrase': String(500),
'short_phrase': String(10),
'with_missing': String(100),
'dttm': Date(),
'value': Float(),
"phrase": String(500),
"short_phrase": String(10),
"with_missing": String(100),
"dttm": Date(),
"value": Float(),
},
index=False)
print('Done loading table!')
print('-' * 80)
index=False,
)
print("Done loading table!")
print("-" * 80)
print('Creating table [unicode_test] reference')
obj = db.session.query(TBL).filter_by(table_name='unicode_test').first()
print("Creating table [unicode_test] reference")
obj = db.session.query(TBL).filter_by(table_name="unicode_test").first()
if not obj:
obj = TBL(table_name='unicode_test')
obj.main_dttm_col = 'dttm'
obj = TBL(table_name="unicode_test")
obj.main_dttm_col = "dttm"
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
@ -71,37 +73,33 @@ def load_unicode_test_data():
tbl = obj
slice_data = {
'granularity_sqla': 'dttm',
'groupby': [],
'metric': 'sum__value',
'row_limit': config.get('ROW_LIMIT'),
'since': '100 years ago',
'until': 'now',
'where': '',
'viz_type': 'word_cloud',
'size_from': '10',
'series': 'short_phrase',
'size_to': '70',
'rotation': 'square',
'limit': '100',
"granularity_sqla": "dttm",
"groupby": [],
"metric": "sum__value",
"row_limit": config.get("ROW_LIMIT"),
"since": "100 years ago",
"until": "now",
"where": "",
"viz_type": "word_cloud",
"size_from": "10",
"series": "short_phrase",
"size_to": "70",
"rotation": "square",
"limit": "100",
}
print('Creating a slice')
print("Creating a slice")
slc = Slice(
slice_name='Unicode Cloud',
viz_type='word_cloud',
datasource_type='table',
slice_name="Unicode Cloud",
viz_type="word_cloud",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(slice_data),
)
merge_slice(slc)
print('Creating a dashboard')
dash = (
db.session.query(Dash)
.filter_by(dashboard_title='Unicode Test')
.first()
)
print("Creating a dashboard")
dash = db.session.query(Dash).filter_by(dashboard_title="Unicode Test").first()
if not dash:
dash = Dash()
@ -145,11 +143,11 @@ def load_unicode_test_data():
"DASHBOARD_VERSION_KEY": "v2"
}
"""
dash.dashboard_title = 'Unicode Test'
dash.dashboard_title = "Unicode Test"
pos = json.loads(js)
update_slice_ids(pos, [slc])
dash.position_json = json.dumps(pos, indent=4)
dash.slug = 'unicode-test'
dash.slug = "unicode-test"
dash.slices = [slc]
db.session.merge(dash)
db.session.commit()
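
The loader above follows a common pandas-to-SQL pattern: read the CSV into a DataFrame, derive the extra columns, then write it out with explicit SQLAlchemy column types so string lengths and the date/float columns stay stable across backends. A minimal standalone sketch of that pattern, assuming an in-memory SQLite engine and an invented table name rather than Superset's main database:

import datetime
import random

import pandas as pd
from sqlalchemy import Date, Float, String, create_engine

engine = create_engine("sqlite://")  # illustrative engine, not Superset's main DB

df = pd.DataFrame({"phrase": ["alpha", "beta"], "short_phrase": ["a", "b"]})
df["dttm"] = datetime.datetime.now().date()
df["value"] = [random.randint(1, 100) for _ in range(len(df))]

# Explicit dtypes pin down string lengths and the date/float columns on the target DB.
df.to_sql(
    "unicode_test_sketch",
    engine,
    if_exists="replace",
    chunksize=500,
    dtype={
        "phrase": String(500),
        "short_phrase": String(10),
        "dttm": Date(),
        "value": Float(),
    },
    index=False,
)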


@ -43,232 +43,270 @@ from .helpers import (
def load_world_bank_health_n_pop():
"""Loads the world bank health dataset, slices and a dashboard"""
tbl_name = 'wb_health_population'
data = get_example_data('countries.json.gz')
tbl_name = "wb_health_population"
data = get_example_data("countries.json.gz")
pdf = pd.read_json(data)
pdf.columns = [col.replace('.', '_') for col in pdf.columns]
pdf.columns = [col.replace(".", "_") for col in pdf.columns]
pdf.year = pd.to_datetime(pdf.year)
pdf.to_sql(
tbl_name,
db.engine,
if_exists='replace',
if_exists="replace",
chunksize=50,
dtype={
'year': DateTime(),
'country_code': String(3),
'country_name': String(255),
'region': String(255),
"year": DateTime(),
"country_code": String(3),
"country_name": String(255),
"region": String(255),
},
index=False)
index=False,
)
print('Creating table [wb_health_population] reference')
print("Creating table [wb_health_population] reference")
tbl = db.session.query(TBL).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md'))
tbl.main_dttm_col = 'year'
tbl.description = utils.readfile(os.path.join(DATA_FOLDER, "countries.md"))
tbl.main_dttm_col = "year"
tbl.database = utils.get_or_create_main_db()
tbl.filter_select_enabled = True
metrics = [
'sum__SP_POP_TOTL', 'sum__SH_DYN_AIDS', 'sum__SH_DYN_AIDS',
'sum__SP_RUR_TOTL_ZS', 'sum__SP_DYN_LE00_IN', 'sum__SP_RUR_TOTL'
"sum__SP_POP_TOTL",
"sum__SH_DYN_AIDS",
"sum__SH_DYN_AIDS",
"sum__SP_RUR_TOTL_ZS",
"sum__SP_DYN_LE00_IN",
"sum__SP_RUR_TOTL",
]
for m in metrics:
if not any(col.metric_name == m for col in tbl.metrics):
aggr_func = m[:3]
col = str(column(m[5:]).compile(db.engine))
tbl.metrics.append(SqlMetric(
metric_name=m,
expression=f'{aggr_func}({col})',
))
tbl.metrics.append(
SqlMetric(metric_name=m, expression=f"{aggr_func}({col})")
)
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
defaults = {
'compare_lag': '10',
'compare_suffix': 'o10Y',
'limit': '25',
'granularity_sqla': 'year',
'groupby': [],
'metric': 'sum__SP_POP_TOTL',
'metrics': ['sum__SP_POP_TOTL'],
'row_limit': config.get('ROW_LIMIT'),
'since': '2014-01-01',
'until': '2014-01-02',
'time_range': '2014-01-01 : 2014-01-02',
'where': '',
'markup_type': 'markdown',
'country_fieldtype': 'cca3',
'secondary_metric': 'sum__SP_POP_TOTL',
'entity': 'country_code',
'show_bubbles': True,
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity_sqla": "year",
"groupby": [],
"metric": "sum__SP_POP_TOTL",
"metrics": ["sum__SP_POP_TOTL"],
"row_limit": config.get("ROW_LIMIT"),
"since": "2014-01-01",
"until": "2014-01-02",
"time_range": "2014-01-01 : 2014-01-02",
"where": "",
"markup_type": "markdown",
"country_fieldtype": "cca3",
"secondary_metric": "sum__SP_POP_TOTL",
"entity": "country_code",
"show_bubbles": True,
}
print('Creating slices')
print("Creating slices")
slices = [
Slice(
slice_name='Region Filter',
viz_type='filter_box',
datasource_type='table',
slice_name="Region Filter",
viz_type="filter_box",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='filter_box',
viz_type="filter_box",
date_filter=False,
filter_configs=[
{
'asc': False,
'clearable': True,
'column': 'region',
'key': '2s98dfu',
'metric': 'sum__SP_POP_TOTL',
'multiple': True,
}, {
'asc': False,
'clearable': True,
'key': 'li3j2lk',
'column': 'country_name',
'metric': 'sum__SP_POP_TOTL',
'multiple': True,
"asc": False,
"clearable": True,
"column": "region",
"key": "2s98dfu",
"metric": "sum__SP_POP_TOTL",
"multiple": True,
},
])),
{
"asc": False,
"clearable": True,
"key": "li3j2lk",
"column": "country_name",
"metric": "sum__SP_POP_TOTL",
"multiple": True,
},
],
),
),
Slice(
slice_name="World's Population",
viz_type='big_number',
datasource_type='table',
viz_type="big_number",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since='2000',
viz_type='big_number',
compare_lag='10',
metric='sum__SP_POP_TOTL',
compare_suffix='over 10Y')),
since="2000",
viz_type="big_number",
compare_lag="10",
metric="sum__SP_POP_TOTL",
compare_suffix="over 10Y",
),
),
Slice(
slice_name='Most Populated Countries',
viz_type='table',
datasource_type='table',
slice_name="Most Populated Countries",
viz_type="table",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='table',
metrics=['sum__SP_POP_TOTL'],
groupby=['country_name'])),
viz_type="table",
metrics=["sum__SP_POP_TOTL"],
groupby=["country_name"],
),
),
Slice(
slice_name='Growth Rate',
viz_type='line',
datasource_type='table',
slice_name="Growth Rate",
viz_type="line",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='line',
since='1960-01-01',
metrics=['sum__SP_POP_TOTL'],
num_period_compare='10',
groupby=['country_name'])),
viz_type="line",
since="1960-01-01",
metrics=["sum__SP_POP_TOTL"],
num_period_compare="10",
groupby=["country_name"],
),
),
Slice(
slice_name='% Rural',
viz_type='world_map',
datasource_type='table',
slice_name="% Rural",
viz_type="world_map",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='world_map',
metric='sum__SP_RUR_TOTL_ZS',
num_period_compare='10')),
viz_type="world_map",
metric="sum__SP_RUR_TOTL_ZS",
num_period_compare="10",
),
),
Slice(
slice_name='Life Expectancy VS Rural %',
viz_type='bubble',
datasource_type='table',
slice_name="Life Expectancy VS Rural %",
viz_type="bubble",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='bubble',
since='2011-01-01',
until='2011-01-02',
series='region',
viz_type="bubble",
since="2011-01-01",
until="2011-01-02",
series="region",
limit=0,
entity='country_name',
x='sum__SP_RUR_TOTL_ZS',
y='sum__SP_DYN_LE00_IN',
size='sum__SP_POP_TOTL',
max_bubble_size='50',
filters=[{
'col': 'country_code',
'val': [
'TCA', 'MNP', 'DMA', 'MHL', 'MCO', 'SXM', 'CYM',
'TUV', 'IMY', 'KNA', 'ASM', 'ADO', 'AMA', 'PLW',
],
'op': 'not in'}],
)),
entity="country_name",
x="sum__SP_RUR_TOTL_ZS",
y="sum__SP_DYN_LE00_IN",
size="sum__SP_POP_TOTL",
max_bubble_size="50",
filters=[
{
"col": "country_code",
"val": [
"TCA",
"MNP",
"DMA",
"MHL",
"MCO",
"SXM",
"CYM",
"TUV",
"IMY",
"KNA",
"ASM",
"ADO",
"AMA",
"PLW",
],
"op": "not in",
}
],
),
),
Slice(
slice_name='Rural Breakdown',
viz_type='sunburst',
datasource_type='table',
slice_name="Rural Breakdown",
viz_type="sunburst",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
viz_type='sunburst',
groupby=['region', 'country_name'],
secondary_metric='sum__SP_RUR_TOTL',
since='2011-01-01',
until='2011-01-01')),
viz_type="sunburst",
groupby=["region", "country_name"],
secondary_metric="sum__SP_RUR_TOTL",
since="2011-01-01",
until="2011-01-01",
),
),
Slice(
slice_name="World's Pop Growth",
viz_type='area',
datasource_type='table',
viz_type="area",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since='1960-01-01',
until='now',
viz_type='area',
groupby=['region'])),
since="1960-01-01",
until="now",
viz_type="area",
groupby=["region"],
),
),
Slice(
slice_name='Box plot',
viz_type='box_plot',
datasource_type='table',
slice_name="Box plot",
viz_type="box_plot",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since='1960-01-01',
until='now',
whisker_options='Min/max (no outliers)',
x_ticks_layout='staggered',
viz_type='box_plot',
groupby=['region'])),
since="1960-01-01",
until="now",
whisker_options="Min/max (no outliers)",
x_ticks_layout="staggered",
viz_type="box_plot",
groupby=["region"],
),
),
Slice(
slice_name='Treemap',
viz_type='treemap',
datasource_type='table',
slice_name="Treemap",
viz_type="treemap",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since='1960-01-01',
until='now',
viz_type='treemap',
metrics=['sum__SP_POP_TOTL'],
groupby=['region', 'country_code'])),
since="1960-01-01",
until="now",
viz_type="treemap",
metrics=["sum__SP_POP_TOTL"],
groupby=["region", "country_code"],
),
),
Slice(
slice_name='Parallel Coordinates',
viz_type='para',
datasource_type='table',
slice_name="Parallel Coordinates",
viz_type="para",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults,
since='2011-01-01',
until='2011-01-01',
viz_type='para',
since="2011-01-01",
until="2011-01-01",
viz_type="para",
limit=100,
metrics=[
'sum__SP_POP_TOTL',
'sum__SP_RUR_TOTL_ZS',
'sum__SH_DYN_AIDS'],
secondary_metric='sum__SP_POP_TOTL',
series='country_name')),
metrics=["sum__SP_POP_TOTL", "sum__SP_RUR_TOTL_ZS", "sum__SH_DYN_AIDS"],
secondary_metric="sum__SP_POP_TOTL",
series="country_name",
),
),
]
misc_dash_slices.add(slices[-1].slice_name)
for slc in slices:
@ -276,12 +314,13 @@ def load_world_bank_health_n_pop():
print("Creating a World's Health Bank dashboard")
dash_name = "World's Bank Data"
slug = 'world_health'
slug = "world_health"
dash = db.session.query(Dash).filter_by(slug=slug).first()
if not dash:
dash = Dash()
js = textwrap.dedent("""\
js = textwrap.dedent(
"""\
{
"CHART-36bfc934": {
"children": [],
@ -497,7 +536,8 @@ def load_world_bank_health_n_pop():
},
"DASHBOARD_VERSION_KEY": "v2"
}
""")
"""
)
pos = json.loads(js)
update_slice_ids(pos, slices)
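
Each Slice above builds its params the same way: shared defaults plus a few per-slice overrides, serialized to JSON. get_slice_json itself is not shown in this section, so the sketch below only assumes it overlays keyword overrides on the defaults and dumps the result; names and values are illustrative:

import json

defaults = {"granularity_sqla": "year", "metric": "sum__SP_POP_TOTL", "groupby": []}

def slice_params(defaults, **overrides):
    # Overlay per-slice settings on the shared defaults, then serialize for Slice.params.
    return json.dumps({**defaults, **overrides}, indent=4, sort_keys=True)

print(slice_params(defaults, viz_type="world_map", metric="sum__SP_RUR_TOTL_ZS"))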


@ -36,7 +36,7 @@ INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100
def dedup(l, suffix='__', case_sensitive=True):
def dedup(l, suffix="__", case_sensitive=True):
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
@ -44,7 +44,9 @@ def dedup(l, suffix='__', case_sensitive=True):
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
foo,bar,bar__1,bar__2,Bar
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
>>> print(
','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False))
)
foo,bar,bar__1,bar__2,Bar__3
"""
new_l = []
@ -63,18 +65,18 @@ def dedup(l, suffix='__', case_sensitive=True):
class SupersetDataFrame(object):
# Mapping numpy dtype.char to generic database types
type_map = {
'b': 'BOOL', # boolean
'i': 'INT', # (signed) integer
'u': 'INT', # unsigned integer
'l': 'INT', # 64bit integer
'f': 'FLOAT', # floating-point
'c': 'FLOAT', # complex-floating point
'm': None, # timedelta
'M': 'DATETIME', # datetime
'O': 'OBJECT', # (Python) objects
'S': 'BYTE', # (byte-)string
'U': 'STRING', # Unicode
'V': None, # raw data (void)
"b": "BOOL", # boolean
"i": "INT", # (signed) integer
"u": "INT", # unsigned integer
"l": "INT", # 64bit integer
"f": "FLOAT", # floating-point
"c": "FLOAT", # complex-floating point
"m": None, # timedelta
"M": "DATETIME", # datetime
"O": "OBJECT", # (Python) objects
"S": "BYTE", # (byte-)string
"U": "STRING", # Unicode
"V": None, # raw data (void)
}
def __init__(self, data, cursor_description, db_engine_spec):
@ -85,8 +87,7 @@ class SupersetDataFrame(object):
self.column_names = dedup(column_names)
data = data or []
self.df = (
pd.DataFrame(list(data), columns=self.column_names).infer_objects())
self.df = pd.DataFrame(list(data), columns=self.column_names).infer_objects()
self._type_dict = {}
try:
@ -106,9 +107,13 @@ class SupersetDataFrame(object):
@property
def data(self):
# work around for https://github.com/pandas-dev/pandas/issues/18372
data = [dict((k, _maybe_box_datetimelike(v))
for k, v in zip(self.df.columns, np.atleast_1d(row)))
for row in self.df.values]
data = [
dict(
(k, _maybe_box_datetimelike(v))
for k, v in zip(self.df.columns, np.atleast_1d(row))
)
for row in self.df.values
]
for d in data:
for k, v in list(d.items()):
# if an int is too big for Java Script to handle
@ -123,7 +128,7 @@ class SupersetDataFrame(object):
"""Given a numpy dtype, Returns a generic database type"""
if isinstance(dtype, ExtensionDtype):
return cls.type_map.get(dtype.kind)
elif hasattr(dtype, 'char'):
elif hasattr(dtype, "char"):
return cls.type_map.get(dtype.char)
@classmethod
@ -141,10 +146,9 @@ class SupersetDataFrame(object):
@staticmethod
def is_date(np_dtype, db_type_str):
def looks_daty(s):
if isinstance(s, str):
return any([s.lower().startswith(ss) for ss in ('time', 'date')])
return any([s.lower().startswith(ss) for ss in ("time", "date")])
return False
if looks_daty(db_type_str):
@ -157,20 +161,23 @@ class SupersetDataFrame(object):
def is_dimension(cls, dtype, column_name):
if cls.is_id(column_name):
return False
return dtype.name in ('object', 'bool')
return dtype.name in ("object", "bool")
@classmethod
def is_id(cls, column_name):
return column_name.startswith('id') or column_name.endswith('id')
return column_name.startswith("id") or column_name.endswith("id")
@classmethod
def agg_func(cls, dtype, column_name):
# consider checking for key substring too.
if cls.is_id(column_name):
return 'count_distinct'
if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
np.issubdtype(dtype, np.number)):
return 'sum'
return "count_distinct"
if (
hasattr(dtype, "type")
and issubclass(dtype.type, np.generic)
and np.issubdtype(dtype, np.number)
):
return "sum"
return None
@property
@ -188,42 +195,36 @@ class SupersetDataFrame(object):
if sample_size:
sample = self.df.sample(sample_size)
for col in self.df.dtypes.keys():
db_type_str = (
self._type_dict.get(col) or
self.db_type(self.df.dtypes[col])
)
db_type_str = self._type_dict.get(col) or self.db_type(self.df.dtypes[col])
column = {
'name': col,
'agg': self.agg_func(self.df.dtypes[col], col),
'type': db_type_str,
'is_date': self.is_date(self.df.dtypes[col], db_type_str),
'is_dim': self.is_dimension(self.df.dtypes[col], col),
"name": col,
"agg": self.agg_func(self.df.dtypes[col], col),
"type": db_type_str,
"is_date": self.is_date(self.df.dtypes[col], db_type_str),
"is_dim": self.is_dimension(self.df.dtypes[col], col),
}
if not db_type_str or db_type_str.upper() == 'OBJECT':
if not db_type_str or db_type_str.upper() == "OBJECT":
v = sample[col].iloc[0] if not sample[col].empty else None
if isinstance(v, str):
column['type'] = 'STRING'
column["type"] = "STRING"
elif isinstance(v, int):
column['type'] = 'INT'
column["type"] = "INT"
elif isinstance(v, float):
column['type'] = 'FLOAT'
column["type"] = "FLOAT"
elif isinstance(v, (datetime, date)):
column['type'] = 'DATETIME'
column['is_date'] = True
column['is_dim'] = False
column["type"] = "DATETIME"
column["is_date"] = True
column["is_dim"] = False
# check if encoded datetime
if (
column['type'] == 'STRING' and
self.datetime_conversion_rate(sample[col]) >
INFER_COL_TYPES_THRESHOLD):
column.update({
'is_date': True,
'is_dim': False,
'agg': None,
})
column["type"] == "STRING"
and self.datetime_conversion_rate(sample[col])
> INFER_COL_TYPES_THRESHOLD
):
column.update({"is_date": True, "is_dim": False, "agg": None})
# 'agg' is optional attribute
if not column['agg']:
column.pop('agg', None)
if not column["agg"]:
column.pop("agg", None)
columns.append(column)
return columns


@ -39,11 +39,14 @@ from superset.db_engine_specs.base import BaseEngineSpec
engines: Dict[str, Type[BaseEngineSpec]] = {}
for (_, name, _) in pkgutil.iter_modules([Path(__file__).parent]): # type: ignore
imported_module = import_module('.' + name, package=__name__)
imported_module = import_module("." + name, package=__name__)
for i in dir(imported_module):
attribute = getattr(imported_module, i)
if inspect.isclass(attribute) and issubclass(attribute, BaseEngineSpec) \
and attribute.engine != '':
if (
inspect.isclass(attribute)
and issubclass(attribute, BaseEngineSpec)
and attribute.engine != ""
):
engines[attribute.engine] = attribute
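
The loop above is a small plugin-discovery pass: walk the package's modules, import each one, and register every BaseEngineSpec subclass that declares a non-empty engine key. A sketch of just the class-scan step, with made-up spec classes and a plain namespace dict standing in for the module walk:

import inspect

class BaseSpec:
    engine = ""  # subclasses set a non-empty key

class SQLiteSpec(BaseSpec):
    engine = "sqlite"

class PrestoSpec(BaseSpec):
    engine = "presto"

def build_registry(namespace):
    # Keep any BaseSpec subclass with a non-empty engine, keyed by that engine string.
    return {
        obj.engine: obj
        for obj in namespace.values()
        if inspect.isclass(obj) and issubclass(obj, BaseSpec) and obj.engine != ""
    }

print(build_registry(dict(globals())))  # registers SQLiteSpec and PrestoSpec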


@ -19,37 +19,36 @@ from superset.db_engine_specs.base import BaseEngineSpec
class AthenaEngineSpec(BaseEngineSpec):
engine = 'awsathena'
engine = "awsathena"
time_grain_functions = {
None: '{col}',
'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))",
'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))",
'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))",
'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))",
'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))",
'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))",
'P1W/1970-01-03T00:00:00Z': "date_add('day', 5, date_trunc('week', \
None: "{col}",
"PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))",
"PT1M": "date_trunc('minute', CAST({col} AS TIMESTAMP))",
"PT1H": "date_trunc('hour', CAST({col} AS TIMESTAMP))",
"P1D": "date_trunc('day', CAST({col} AS TIMESTAMP))",
"P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))",
"P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))",
"P0.25Y": "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
"P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))",
"P1W/1970-01-03T00:00:00Z": "date_add('day', 5, date_trunc('week', \
date_add('day', 1, CAST({col} AS TIMESTAMP))))",
'1969-12-28T00:00:00Z/P1W': "date_add('day', -1, date_trunc('week', \
"1969-12-28T00:00:00Z/P1W": "date_add('day', -1, date_trunc('week', \
date_add('day', 1, CAST({col} AS TIMESTAMP))))",
}
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
if tt == "DATE":
return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
if tt == 'TIMESTAMP':
if tt == "TIMESTAMP":
return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
return ("CAST ('{}' AS TIMESTAMP)"
.format(dttm.strftime('%Y-%m-%d %H:%M:%S')))
return "CAST ('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
return "from_unixtime({col})"
@staticmethod
def mutate_label(label):


@ -37,7 +37,7 @@ from werkzeug.utils import secure_filename
from superset import app, db, sql_parse
from superset.utils import core as utils
Grain = namedtuple('Grain', 'name label function duration')
Grain = namedtuple("Grain", "name label function duration")
config = app.config
@ -46,23 +46,23 @@ QueryStatus = utils.QueryStatus
config = app.config
builtin_time_grains = {
None: 'Time Column',
'PT1S': 'second',
'PT1M': 'minute',
'PT5M': '5 minute',
'PT10M': '10 minute',
'PT15M': '15 minute',
'PT0.5H': 'half hour',
'PT1H': 'hour',
'P1D': 'day',
'P1W': 'week',
'P1M': 'month',
'P0.25Y': 'quarter',
'P1Y': 'year',
'1969-12-28T00:00:00Z/P1W': 'week_start_sunday',
'1969-12-29T00:00:00Z/P1W': 'week_start_monday',
'P1W/1970-01-03T00:00:00Z': 'week_ending_saturday',
'P1W/1970-01-04T00:00:00Z': 'week_ending_sunday',
None: "Time Column",
"PT1S": "second",
"PT1M": "minute",
"PT5M": "5 minute",
"PT10M": "10 minute",
"PT15M": "15 minute",
"PT0.5H": "half hour",
"PT1H": "hour",
"P1D": "day",
"P1W": "week",
"P1M": "month",
"P0.25Y": "quarter",
"P1Y": "year",
"1969-12-28T00:00:00Z/P1W": "week_start_sunday",
"1969-12-29T00:00:00Z/P1W": "week_start_monday",
"P1W/1970-01-03T00:00:00Z": "week_ending_saturday",
"P1W/1970-01-04T00:00:00Z": "week_ending_sunday",
}
@ -81,14 +81,15 @@ class TimestampExpression(ColumnClause):
@compiles(TimestampExpression)
def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
return element.name.replace('{col}', compiler.process(element.col, **kw))
return element.name.replace("{col}", compiler.process(element.col, **kw))
class LimitMethod(object):
"""Enum the ways that limits can be applied"""
FETCH_MANY = 'fetch_many'
WRAP_SQL = 'wrap_sql'
FORCE_LIMIT = 'force_limit'
FETCH_MANY = "fetch_many"
WRAP_SQL = "wrap_sql"
FORCE_LIMIT = "force_limit"
def create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
@ -104,7 +105,7 @@ def create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
class BaseEngineSpec(object):
"""Abstract class for database engine specific configurations"""
engine = 'base' # str as defined in sqlalchemy.engine.engine
engine = "base" # str as defined in sqlalchemy.engine.engine
time_grain_functions: Dict[Optional[str], str] = {}
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
@ -118,8 +119,9 @@ class BaseEngineSpec(object):
try_remove_schema_from_table_name = True
@classmethod
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
time_grain: Optional[str]) -> TimestampExpression:
def get_timestamp_expr(
cls, col: ColumnClause, pdf: Optional[str], time_grain: Optional[str]
) -> TimestampExpression:
"""
Construct a TimeExpression to be used in a SQLAlchemy query.
@ -132,25 +134,26 @@ class BaseEngineSpec(object):
time_expr = cls.time_grain_functions.get(time_grain)
if not time_expr:
raise NotImplementedError(
f'No grain spec for {time_grain} for database {cls.engine}')
f"No grain spec for {time_grain} for database {cls.engine}"
)
else:
time_expr = '{col}'
time_expr = "{col}"
# if epoch, translate to DATE using db specific conf
if pdf == 'epoch_s':
time_expr = time_expr.replace('{col}', cls.epoch_to_dttm())
elif pdf == 'epoch_ms':
time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm())
if pdf == "epoch_s":
time_expr = time_expr.replace("{col}", cls.epoch_to_dttm())
elif pdf == "epoch_ms":
time_expr = time_expr.replace("{col}", cls.epoch_ms_to_dttm())
return TimestampExpression(time_expr, col, type_=DateTime)
@classmethod
def get_time_grains(cls):
blacklist = config.get('TIME_GRAIN_BLACKLIST', [])
blacklist = config.get("TIME_GRAIN_BLACKLIST", [])
grains = builtin_time_grains.copy()
grains.update(config.get('TIME_GRAIN_ADDONS', {}))
grains.update(config.get("TIME_GRAIN_ADDONS", {}))
grain_functions = cls.time_grain_functions.copy()
grain_addon_functions = config.get('TIME_GRAIN_ADDON_FUNCTIONS', {})
grain_addon_functions = config.get("TIME_GRAIN_ADDON_FUNCTIONS", {})
grain_functions.update(grain_addon_functions.get(cls.engine, {}))
return create_time_grains_tuple(grains, grain_functions, blacklist)
@ -169,9 +172,9 @@ class BaseEngineSpec(object):
return cursor.fetchall()
@classmethod
def expand_data(cls,
columns: List[dict],
data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]:
def expand_data(
cls, columns: List[dict], data: List[dict]
) -> Tuple[List[dict], List[dict], List[dict]]:
return columns, data, []
@classmethod
@ -189,7 +192,7 @@ class BaseEngineSpec(object):
@classmethod
def epoch_ms_to_dttm(cls):
return cls.epoch_to_dttm().replace('{col}', '({col}/1000)')
return cls.epoch_to_dttm().replace("{col}", "({col}/1000)")
@classmethod
def get_datatype(cls, type_code):
@ -205,12 +208,10 @@ class BaseEngineSpec(object):
def apply_limit_to_sql(cls, sql, limit, database):
"""Alters the SQL statement to apply a LIMIT clause"""
if cls.limit_method == LimitMethod.WRAP_SQL:
sql = sql.strip('\t\n ;')
sql = sql.strip("\t\n ;")
qry = (
select('*')
.select_from(
TextAsFrom(text(sql), ['*']).alias('inner_qry'),
)
select("*")
.select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry"))
.limit(limit)
)
return database.compile_sqla_query(qry)
@ -235,10 +236,11 @@ class BaseEngineSpec(object):
:param kwargs: params to be passed to DataFrame.read_csv
:return: Pandas DataFrame containing data from csv
"""
kwargs['filepath_or_buffer'] = \
config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
kwargs['encoding'] = 'utf-8'
kwargs['iterator'] = True
kwargs["filepath_or_buffer"] = (
config["UPLOAD_FOLDER"] + kwargs["filepath_or_buffer"]
)
kwargs["encoding"] = "utf-8"
kwargs["iterator"] = True
chunks = pd.read_csv(**kwargs)
df = pd.concat(chunk for chunk in chunks)
return df
@ -260,39 +262,42 @@ class BaseEngineSpec(object):
:param form: Parameters defining how to process data
:param table: Metadata of new table to be created
"""
def _allowed_file(filename: str) -> bool:
# Only allow specific file extensions as specified in the config
extension = os.path.splitext(filename)[1]
return extension is not None and extension[1:] in config['ALLOWED_EXTENSIONS']
return (
extension is not None and extension[1:] in config["ALLOWED_EXTENSIONS"]
)
filename = secure_filename(form.csv_file.data.filename)
if not _allowed_file(filename):
raise Exception('Invalid file type selected')
raise Exception("Invalid file type selected")
csv_to_df_kwargs = {
'filepath_or_buffer': filename,
'sep': form.sep.data,
'header': form.header.data if form.header.data else 0,
'index_col': form.index_col.data,
'mangle_dupe_cols': form.mangle_dupe_cols.data,
'skipinitialspace': form.skipinitialspace.data,
'skiprows': form.skiprows.data,
'nrows': form.nrows.data,
'skip_blank_lines': form.skip_blank_lines.data,
'parse_dates': form.parse_dates.data,
'infer_datetime_format': form.infer_datetime_format.data,
'chunksize': 10000,
"filepath_or_buffer": filename,
"sep": form.sep.data,
"header": form.header.data if form.header.data else 0,
"index_col": form.index_col.data,
"mangle_dupe_cols": form.mangle_dupe_cols.data,
"skipinitialspace": form.skipinitialspace.data,
"skiprows": form.skiprows.data,
"nrows": form.nrows.data,
"skip_blank_lines": form.skip_blank_lines.data,
"parse_dates": form.parse_dates.data,
"infer_datetime_format": form.infer_datetime_format.data,
"chunksize": 10000,
}
df = cls.csv_to_df(**csv_to_df_kwargs)
df_to_sql_kwargs = {
'df': df,
'name': form.name.data,
'con': create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False),
'schema': form.schema.data,
'if_exists': form.if_exists.data,
'index': form.index.data,
'index_label': form.index_label.data,
'chunksize': 10000,
"df": df,
"name": form.name.data,
"con": create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False),
"schema": form.schema.data,
"if_exists": form.if_exists.data,
"index": form.index.data,
"index_label": form.index_label.data,
"chunksize": 10000,
}
cls.df_to_sql(**df_to_sql_kwargs)
@ -304,34 +309,41 @@ class BaseEngineSpec(object):
@classmethod
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
def get_all_datasource_names(
cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
"""Returns a list of all tables or views in database.
:param db: Database instance
:param datasource_type: Datasource_type can be 'table' or 'view'
:return: List of all datasources in database or schema
"""
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True)
schemas = db.get_all_schema_names(
cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True,
)
all_datasources: List[utils.DatasourceName] = []
for schema in schemas:
if datasource_type == 'table':
if datasource_type == "table":
all_datasources += db.get_all_table_names_in_schema(
schema=schema, force=True,
schema=schema,
force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
elif datasource_type == 'view':
cache_timeout=db.table_cache_timeout,
)
elif datasource_type == "view":
all_datasources += db.get_all_view_names_in_schema(
schema=schema, force=True,
schema=schema,
force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
cache_timeout=db.table_cache_timeout,
)
else:
raise Exception(f'Unsupported datasource_type: {datasource_type}')
raise Exception(f"Unsupported datasource_type: {datasource_type}")
return all_datasources
@classmethod
@ -381,14 +393,14 @@ class BaseEngineSpec(object):
def get_table_names(cls, inspector, schema):
tables = inspector.get_table_names(schema)
if schema and cls.try_remove_schema_from_table_name:
tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
tables = [re.sub(f"^{schema}\\.", "", table) for table in tables]
return sorted(tables)
@classmethod
def get_view_names(cls, inspector, schema):
views = inspector.get_view_names(schema)
if schema and cls.try_remove_schema_from_table_name:
views = [re.sub(f'^{schema}\\.', '', view) for view in views]
views = [re.sub(f"^{schema}\\.", "", view) for view in views]
return sorted(views)
@classmethod
@ -396,19 +408,27 @@ class BaseEngineSpec(object):
return inspector.get_columns(table_name, schema)
@classmethod
def where_latest_partition(
cls, table_name, schema, database, qry, columns=None):
def where_latest_partition(cls, table_name, schema, database, qry, columns=None):
return False
@classmethod
def _get_fields(cls, cols):
return [column(c.get('name')) for c in cols]
return [column(c.get("name")) for c in cols]
@classmethod
def select_star(cls, my_db, table_name, engine, schema=None, limit=100,
show_cols=False, indent=True, latest_partition=True,
cols=None):
fields = '*'
def select_star(
cls,
my_db,
table_name,
engine,
schema=None,
limit=100,
show_cols=False,
indent=True,
latest_partition=True,
cols=None,
):
fields = "*"
cols = cols or []
if (show_cols or latest_partition) and not cols:
cols = my_db.get_columns(table_name, schema)
@ -417,7 +437,7 @@ class BaseEngineSpec(object):
fields = cls._get_fields(cols)
quote = engine.dialect.identifier_preparer.quote
if schema:
full_table_name = quote(schema) + '.' + quote(table_name)
full_table_name = quote(schema) + "." + quote(table_name)
else:
full_table_name = quote(table_name)
@ -427,7 +447,8 @@ class BaseEngineSpec(object):
qry = qry.limit(limit)
if latest_partition:
partition_query = cls.where_latest_partition(
table_name, schema, my_db, qry, columns=cols)
table_name, schema, my_db, qry, columns=cols
)
if partition_query != False: # noqa
qry = partition_query
sql = my_db.compile_sqla_query(qry)
@ -475,7 +496,10 @@ class BaseEngineSpec(object):
generate a truncated label by calling truncate_label().
"""
label_mutated = cls.mutate_label(label)
if cls.max_column_name_length and len(label_mutated) > cls.max_column_name_length:
if (
cls.max_column_name_length
and len(label_mutated) > cls.max_column_name_length
):
label_mutated = cls.truncate_label(label)
if cls.force_column_alias_quotes:
label_mutated = quoted_name(label_mutated, True)
@ -510,10 +534,10 @@ class BaseEngineSpec(object):
this method is used to construct a deterministic and unique label based on
an md5 hash.
"""
label = hashlib.md5(label.encode('utf-8')).hexdigest()
label = hashlib.md5(label.encode("utf-8")).hexdigest()
# truncate hash if it exceeds max length
if cls.max_column_name_length and len(label) > cls.max_column_name_length:
label = label[:cls.max_column_name_length]
label = label[: cls.max_column_name_length]
return label
@classmethod
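
A way to read get_timestamp_expr together with the compile_timegrain_expression hook earlier in this file: each grain is a SQL template containing a {col} placeholder, epoch columns first swap {col} for the engine's epoch-to-datetime wrapper, and the concrete column expression is substituted last at compile time. A rough illustration with plain string substitution, reusing an Athena-style grain from this diff and an invented column name:

grain_template = "date_trunc('hour', CAST({col} AS TIMESTAMP))"
epoch_to_dttm = "from_unixtime({col})"

# epoch_s column: wrap the column in the epoch conversion before applying the grain
expr = grain_template.replace("{col}", epoch_to_dttm)
# then substitute the concrete column expression, as the compiler hook does
expr = expr.replace("{col}", "event_ts")
print(expr)  # date_trunc('hour', CAST(from_unixtime(event_ts) AS TIMESTAMP))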


@ -27,7 +27,8 @@ class BigQueryEngineSpec(BaseEngineSpec):
"""Engine spec for Google's BigQuery
As contributed by @mxmzdlv on issue #945"""
engine = 'bigquery'
engine = "bigquery"
max_column_name_length = 128
"""
@ -43,28 +44,28 @@ class BigQueryEngineSpec(BaseEngineSpec):
arraysize = 5000
time_grain_functions = {
None: '{col}',
'PT1S': 'TIMESTAMP_TRUNC({col}, SECOND)',
'PT1M': 'TIMESTAMP_TRUNC({col}, MINUTE)',
'PT1H': 'TIMESTAMP_TRUNC({col}, HOUR)',
'P1D': 'TIMESTAMP_TRUNC({col}, DAY)',
'P1W': 'TIMESTAMP_TRUNC({col}, WEEK)',
'P1M': 'TIMESTAMP_TRUNC({col}, MONTH)',
'P0.25Y': 'TIMESTAMP_TRUNC({col}, QUARTER)',
'P1Y': 'TIMESTAMP_TRUNC({col}, YEAR)',
None: "{col}",
"PT1S": "TIMESTAMP_TRUNC({col}, SECOND)",
"PT1M": "TIMESTAMP_TRUNC({col}, MINUTE)",
"PT1H": "TIMESTAMP_TRUNC({col}, HOUR)",
"P1D": "TIMESTAMP_TRUNC({col}, DAY)",
"P1W": "TIMESTAMP_TRUNC({col}, WEEK)",
"P1M": "TIMESTAMP_TRUNC({col}, MONTH)",
"P0.25Y": "TIMESTAMP_TRUNC({col}, QUARTER)",
"P1Y": "TIMESTAMP_TRUNC({col}, YEAR)",
}
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
return "'{}'".format(dttm.strftime('%Y-%m-%d'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
if tt == "DATE":
return "'{}'".format(dttm.strftime("%Y-%m-%d"))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def fetch_data(cls, cursor, limit):
data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit)
if data and type(data[0]).__name__ == 'Row':
if data and type(data[0]).__name__ == "Row":
data = [r.values() for r in data]
return data
@ -78,13 +79,13 @@ class BigQueryEngineSpec(BaseEngineSpec):
:param str label: the original label which might include unsupported characters
:return: String that is supported by the database
"""
label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
# if label starts with number, add underscore as first character
label_mutated = '_' + label if re.match(r'^\d', label) else label
label_mutated = "_" + label if re.match(r"^\d", label) else label
# replace non-alphanumeric characters with underscores
label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
label_mutated = re.sub(r"[^\w]+", "_", label_mutated)
if label_mutated != label:
# add first 5 chars from md5 hash to label to avoid possible collisions
label_mutated += label_hashed[:6]
@ -97,7 +98,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
underscore. To make sure this is always the case, an underscore is prefixed
to the truncated label.
"""
return '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
return "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
@ -105,20 +106,18 @@ class BigQueryEngineSpec(BaseEngineSpec):
if not indexes:
return {}
partitions_columns = [
index.get('column_names', []) for index in indexes
if index.get('name') == 'partition'
index.get("column_names", [])
for index in indexes
if index.get("name") == "partition"
]
cluster_columns = [
index.get('column_names', []) for index in indexes
if index.get('name') == 'clustering'
index.get("column_names", [])
for index in indexes
if index.get("name") == "clustering"
]
return {
'partitions': {
'cols': partitions_columns,
},
'clustering': {
'cols': cluster_columns,
},
"partitions": {"cols": partitions_columns},
"clustering": {"cols": cluster_columns},
}
@classmethod
@ -131,16 +130,18 @@ class BigQueryEngineSpec(BaseEngineSpec):
Also explicility specifying column names so we don't encounter duplicate
column names in the result.
"""
return [literal_column(c.get('name')).label(c.get('name').replace('.', '__'))
for c in cols]
return [
literal_column(c.get("name")).label(c.get("name").replace(".", "__"))
for c in cols
]
@classmethod
def epoch_to_dttm(cls):
return 'TIMESTAMP_SECONDS({col})'
return "TIMESTAMP_SECONDS({col})"
@classmethod
def epoch_ms_to_dttm(cls):
return 'TIMESTAMP_MILLIS({col})'
return "TIMESTAMP_MILLIS({col})"
@classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs):
@ -156,18 +157,20 @@ class BigQueryEngineSpec(BaseEngineSpec):
try:
import pandas_gbq
except ImportError:
raise Exception('Could not import the library `pandas_gbq`, which is '
'required to be installed in your environment in order '
'to upload data to BigQuery')
raise Exception(
"Could not import the library `pandas_gbq`, which is "
"required to be installed in your environment in order "
"to upload data to BigQuery"
)
if not ('name' in kwargs and 'schema' in kwargs):
raise Exception('name and schema need to be defined in kwargs')
if not ("name" in kwargs and "schema" in kwargs):
raise Exception("name and schema need to be defined in kwargs")
gbq_kwargs = {}
gbq_kwargs['project_id'] = kwargs['con'].engine.url.host
gbq_kwargs['destination_table'] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}"
gbq_kwargs["project_id"] = kwargs["con"].engine.url.host
gbq_kwargs["destination_table"] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}"
# Only pass through supported kwargs
supported_kwarg_keys = {'if_exists'}
supported_kwarg_keys = {"if_exists"}
for key in supported_kwarg_keys:
if key in kwargs:
gbq_kwargs[key] = kwargs[key]
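
The label rules in mutate_label above are easier to follow outside the diff: prefix an underscore when the label starts with a digit, collapse runs of non-word characters to underscores, and append a short md5-derived suffix whenever the label changed so distinct originals cannot collide. Restated as a standalone helper (the function name is invented):

import hashlib
import re

def bq_safe_label(label):
    hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
    mutated = "_" + label if re.match(r"^\d", label) else label
    mutated = re.sub(r"[^\w]+", "_", mutated)
    # Only decorate labels that actually changed, mirroring the check above.
    return mutated + hashed[:6] if mutated != label else mutated

print(bq_safe_label("count(*)"))  # count_ followed by a 6-char hash suffix
print(bq_safe_label("revenue"))   # unchanged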


@ -21,32 +21,31 @@ from superset.db_engine_specs.base import BaseEngineSpec
class ClickHouseEngineSpec(BaseEngineSpec):
"""Dialect for ClickHouse analytical DB."""
engine = 'clickhouse'
engine = "clickhouse"
time_secondary_columns = True
time_groupby_inline = True
time_grain_functions = {
None: '{col}',
'PT1M': 'toStartOfMinute(toDateTime({col}))',
'PT5M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)',
'PT10M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)',
'PT15M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)',
'PT0.5H': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)',
'PT1H': 'toStartOfHour(toDateTime({col}))',
'P1D': 'toStartOfDay(toDateTime({col}))',
'P1W': 'toMonday(toDateTime({col}))',
'P1M': 'toStartOfMonth(toDateTime({col}))',
'P0.25Y': 'toStartOfQuarter(toDateTime({col}))',
'P1Y': 'toStartOfYear(toDateTime({col}))',
None: "{col}",
"PT1M": "toStartOfMinute(toDateTime({col}))",
"PT5M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)",
"PT10M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)",
"PT15M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)",
"PT0.5H": "toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)",
"PT1H": "toStartOfHour(toDateTime({col}))",
"P1D": "toStartOfDay(toDateTime({col}))",
"P1W": "toMonday(toDateTime({col}))",
"P1M": "toStartOfMonth(toDateTime({col}))",
"P0.25Y": "toStartOfQuarter(toDateTime({col}))",
"P1Y": "toStartOfYear(toDateTime({col}))",
}
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
return "toDate('{}')".format(dttm.strftime('%Y-%m-%d'))
if tt == 'DATETIME':
return "toDateTime('{}')".format(
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
if tt == "DATE":
return "toDate('{}')".format(dttm.strftime("%Y-%m-%d"))
if tt == "DATETIME":
return "toDateTime('{}')".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))


@ -19,34 +19,32 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
engine = "ibm_db_sa"
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
max_column_name_length = 30
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'
' - MICROSECOND({col}) MICROSECONDS',
'PT1M': 'CAST({col} as TIMESTAMP)'
' - SECOND({col}) SECONDS'
' - MICROSECOND({col}) MICROSECONDS',
'PT1H': 'CAST({col} as TIMESTAMP)'
' - MINUTE({col}) MINUTES'
' - SECOND({col}) SECONDS'
' - MICROSECOND({col}) MICROSECONDS ',
'P1D': 'CAST({col} as TIMESTAMP)'
' - HOUR({col}) HOURS'
' - MINUTE({col}) MINUTES'
' - SECOND({col}) SECONDS'
' - MICROSECOND({col}) MICROSECONDS',
'P1W': '{col} - (DAYOFWEEK({col})) DAYS',
'P1M': '{col} - (DAY({col})-1) DAYS',
'P0.25Y': '{col} - (DAY({col})-1) DAYS'
' - (MONTH({col})-1) MONTHS'
' + ((QUARTER({col})-1) * 3) MONTHS',
'P1Y': '{col} - (DAY({col})-1) DAYS'
' - (MONTH({col})-1) MONTHS',
None: "{col}",
"PT1S": "CAST({col} as TIMESTAMP)" " - MICROSECOND({col}) MICROSECONDS",
"PT1M": "CAST({col} as TIMESTAMP)"
" - SECOND({col}) SECONDS"
" - MICROSECOND({col}) MICROSECONDS",
"PT1H": "CAST({col} as TIMESTAMP)"
" - MINUTE({col}) MINUTES"
" - SECOND({col}) SECONDS"
" - MICROSECOND({col}) MICROSECONDS ",
"P1D": "CAST({col} as TIMESTAMP)"
" - HOUR({col}) HOURS"
" - MINUTE({col}) MINUTES"
" - SECOND({col}) SECONDS"
" - MICROSECOND({col}) MICROSECONDS",
"P1W": "{col} - (DAYOFWEEK({col})) DAYS",
"P1M": "{col} - (DAY({col})-1) DAYS",
"P0.25Y": "{col} - (DAY({col})-1) DAYS"
" - (MONTH({col})-1) MONTHS"
" + ((QUARTER({col})-1) * 3) MONTHS",
"P1Y": "{col} - (DAY({col})-1) DAYS" " - (MONTH({col})-1) MONTHS",
}
@classmethod
@ -55,4 +53,4 @@ class Db2EngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))
return "'{}'".format(dttm.strftime("%Y-%m-%d-%H.%M.%S"))


@ -22,43 +22,43 @@ from superset.db_engine_specs.base import BaseEngineSpec
class DrillEngineSpec(BaseEngineSpec):
"""Engine spec for Apache Drill"""
engine = 'drill'
engine = "drill"
time_grain_functions = {
None: '{col}',
'PT1S': "NEARESTDATE({col}, 'SECOND')",
'PT1M': "NEARESTDATE({col}, 'MINUTE')",
'PT15M': "NEARESTDATE({col}, 'QUARTER_HOUR')",
'PT0.5H': "NEARESTDATE({col}, 'HALF_HOUR')",
'PT1H': "NEARESTDATE({col}, 'HOUR')",
'P1D': "NEARESTDATE({col}, 'DAY')",
'P1W': "NEARESTDATE({col}, 'WEEK_SUNDAY')",
'P1M': "NEARESTDATE({col}, 'MONTH')",
'P0.25Y': "NEARESTDATE({col}, 'QUARTER')",
'P1Y': "NEARESTDATE({col}, 'YEAR')",
None: "{col}",
"PT1S": "NEARESTDATE({col}, 'SECOND')",
"PT1M": "NEARESTDATE({col}, 'MINUTE')",
"PT15M": "NEARESTDATE({col}, 'QUARTER_HOUR')",
"PT0.5H": "NEARESTDATE({col}, 'HALF_HOUR')",
"PT1H": "NEARESTDATE({col}, 'HOUR')",
"P1D": "NEARESTDATE({col}, 'DAY')",
"P1W": "NEARESTDATE({col}, 'WEEK_SUNDAY')",
"P1M": "NEARESTDATE({col}, 'MONTH')",
"P0.25Y": "NEARESTDATE({col}, 'QUARTER')",
"P1Y": "NEARESTDATE({col}, 'YEAR')",
}
# Returns a function to convert a Unix timestamp in milliseconds to a date
@classmethod
def epoch_to_dttm(cls):
return cls.epoch_ms_to_dttm().replace('{col}', '({col}*1000)')
return cls.epoch_ms_to_dttm().replace("{col}", "({col}*1000)")
@classmethod
def epoch_ms_to_dttm(cls):
return 'TO_DATE({col})'
return "TO_DATE({col})"
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
if tt == "DATE":
return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
elif tt == 'TIMESTAMP':
return "CAST('{}' AS TIMESTAMP)".format(
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
elif tt == "TIMESTAMP":
return "CAST('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def adjust_database_uri(cls, uri, selected_schema):
if selected_schema:
uri.database = parse.quote(selected_schema, safe='')
uri.database = parse.quote(selected_schema, safe="")
return uri


@ -20,23 +20,24 @@ from superset.db_engine_specs.base import BaseEngineSpec
class DruidEngineSpec(BaseEngineSpec):
"""Engine spec for Druid.io"""
engine = 'druid'
engine = "druid"
inner_joins = False
allows_subquery = False
time_grain_functions = {
None: '{col}',
'PT1S': 'FLOOR({col} TO SECOND)',
'PT1M': 'FLOOR({col} TO MINUTE)',
'PT1H': 'FLOOR({col} TO HOUR)',
'P1D': 'FLOOR({col} TO DAY)',
'P1W': 'FLOOR({col} TO WEEK)',
'P1M': 'FLOOR({col} TO MONTH)',
'P0.25Y': 'FLOOR({col} TO QUARTER)',
'P1Y': 'FLOOR({col} TO YEAR)',
None: "{col}",
"PT1S": "FLOOR({col} TO SECOND)",
"PT1M": "FLOOR({col} TO MINUTE)",
"PT1H": "FLOOR({col} TO HOUR)",
"P1D": "FLOOR({col} TO DAY)",
"P1W": "FLOOR({col} TO WEEK)",
"P1M": "FLOOR({col} TO MONTH)",
"P0.25Y": "FLOOR({col} TO QUARTER)",
"P1Y": "FLOOR({col} TO YEAR)",
}
@classmethod
def alter_new_orm_column(cls, orm_col):
if orm_col.column_name == '__time':
if orm_col.column_name == "__time":
orm_col.is_dttm = True


@ -20,6 +20,7 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
class GSheetsEngineSpec(SqliteEngineSpec):
"""Engine for Google spreadsheets"""
engine = 'gsheets'
engine = "gsheets"
inner_joins = False
allows_subquery = False


@ -38,30 +38,30 @@ from superset.utils import core as utils
QueryStatus = utils.QueryStatus
config = app.config
tracking_url_trans = conf.get('TRACKING_URL_TRANSFORMER')
hive_poll_interval = conf.get('HIVE_POLL_INTERVAL')
tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
class HiveEngineSpec(PrestoEngineSpec):
"""Reuses PrestoEngineSpec functionality."""
engine = 'hive'
engine = "hive"
max_column_name_length = 767
# Scoping regex at class level to avoid recompiling
# 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
jobs_stats_r = re.compile(
r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
jobs_stats_r = re.compile(r".*INFO.*Total jobs = (?P<max_jobs>[0-9]+)")
# 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
launching_job_r = re.compile(
'.*INFO.*Launching Job (?P<job_number>[0-9]+) out of '
'(?P<max_jobs>[0-9]+)')
".*INFO.*Launching Job (?P<job_number>[0-9]+) out of " "(?P<max_jobs>[0-9]+)"
)
# 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
# map = 0%, reduce = 0%
stage_progress_r = re.compile(
r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
r'map = (?P<map_progress>[0-9]+)%.*'
r'reduce = (?P<reduce_progress>[0-9]+)%.*')
r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
r"map = (?P<map_progress>[0-9]+)%.*"
r"reduce = (?P<reduce_progress>[0-9]+)%.*"
)
@classmethod
def patch(cls):
@ -70,7 +70,8 @@ class HiveEngineSpec(PrestoEngineSpec):
from TCLIService import (
constants as patched_constants,
ttypes as patched_ttypes,
TCLIService as patched_TCLIService)
TCLIService as patched_TCLIService,
)
hive.TCLIService = patched_TCLIService
hive.constants = patched_constants
@ -78,17 +79,19 @@ class HiveEngineSpec(PrestoEngineSpec):
hive.Cursor.fetch_logs = patched_hive.fetch_logs
@classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
def get_all_datasource_names(
cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
return BaseEngineSpec.get_all_datasource_names(db, datasource_type)
@classmethod
def fetch_data(cls, cursor, limit):
import pyhive
from TCLIService import ttypes
state = cursor.poll()
if state.operationState == ttypes.TOperationState.ERROR_STATE:
raise Exception('Query error', state.errorMessage)
raise Exception("Query error", state.errorMessage)
try:
return super(HiveEngineSpec, cls).fetch_data(cursor, limit)
except pyhive.exc.ProgrammingError:
@ -97,68 +100,76 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def create_table_from_csv(cls, form, table):
"""Uploads a csv file and creates a superset datasource in Hive."""
def convert_to_hive_type(col_type):
"""maps tableschema's types to hive types"""
tableschema_to_hive_types = {
'boolean': 'BOOLEAN',
'integer': 'INT',
'number': 'DOUBLE',
'string': 'STRING',
"boolean": "BOOLEAN",
"integer": "INT",
"number": "DOUBLE",
"string": "STRING",
}
return tableschema_to_hive_types.get(col_type, 'STRING')
return tableschema_to_hive_types.get(col_type, "STRING")
bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']
bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
if not bucket_path:
logging.info('No upload bucket specified')
logging.info("No upload bucket specified")
raise Exception(
'No upload bucket specified. You can specify one in the config file.')
"No upload bucket specified. You can specify one in the config file."
)
table_name = form.name.data
schema_name = form.schema.data
if config.get('UPLOADED_CSV_HIVE_NAMESPACE'):
if '.' in table_name or schema_name:
if config.get("UPLOADED_CSV_HIVE_NAMESPACE"):
if "." in table_name or schema_name:
raise Exception(
"You can't specify a namespace. "
'All tables will be uploaded to the `{}` namespace'.format(
config.get('HIVE_NAMESPACE')))
full_table_name = '{}.{}'.format(
config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name)
"All tables will be uploaded to the `{}` namespace".format(
config.get("HIVE_NAMESPACE")
)
)
full_table_name = "{}.{}".format(
config.get("UPLOADED_CSV_HIVE_NAMESPACE"), table_name
)
else:
if '.' in table_name and schema_name:
if "." in table_name and schema_name:
raise Exception(
"You can't specify a namespace both in the name of the table "
'and in the schema field. Please remove one')
"and in the schema field. Please remove one"
)
full_table_name = '{}.{}'.format(
schema_name, table_name) if schema_name else table_name
full_table_name = (
"{}.{}".format(schema_name, table_name) if schema_name else table_name
)
filename = form.csv_file.data.filename
upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
upload_path = config['UPLOAD_FOLDER'] + \
secure_filename(filename)
upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"]
upload_path = config["UPLOAD_FOLDER"] + secure_filename(filename)
# Optional dependency
from tableschema import Table # pylint: disable=import-error
hive_table_schema = Table(upload_path).infer()
column_name_and_type = []
for column_info in hive_table_schema['fields']:
for column_info in hive_table_schema["fields"]:
column_name_and_type.append(
'`{}` {}'.format(
column_info['name'],
convert_to_hive_type(column_info['type'])))
schema_definition = ', '.join(column_name_and_type)
"`{}` {}".format(
column_info["name"], convert_to_hive_type(column_info["type"])
)
)
schema_definition = ", ".join(column_name_and_type)
# Optional dependency
import boto3 # pylint: disable=import-error
s3 = boto3.client('s3')
location = os.path.join('s3a://', bucket_path, upload_prefix, table_name)
s3 = boto3.client("s3")
location = os.path.join("s3a://", bucket_path, upload_prefix, table_name)
s3.upload_file(
upload_path, bucket_path,
os.path.join(upload_prefix, table_name, filename))
upload_path, bucket_path, os.path.join(upload_prefix, table_name, filename)
)
sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
TEXTFILE LOCATION '{location}'
@ -170,17 +181,16 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
if tt == "DATE":
return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
elif tt == 'TIMESTAMP':
return "CAST('{}' AS TIMESTAMP)".format(
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
elif tt == "TIMESTAMP":
return "CAST('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
if selected_schema:
uri.database = parse.quote(selected_schema, safe='')
uri.database = parse.quote(selected_schema, safe="")
return uri
@classmethod
@ -199,34 +209,32 @@ class HiveEngineSpec(PrestoEngineSpec):
for line in log_lines:
match = cls.jobs_stats_r.match(line)
if match:
total_jobs = int(match.groupdict()['max_jobs']) or 1
total_jobs = int(match.groupdict()["max_jobs"]) or 1
match = cls.launching_job_r.match(line)
if match:
current_job = int(match.groupdict()['job_number'])
total_jobs = int(match.groupdict()['max_jobs']) or 1
current_job = int(match.groupdict()["job_number"])
total_jobs = int(match.groupdict()["max_jobs"]) or 1
stages = {}
match = cls.stage_progress_r.match(line)
if match:
stage_number = int(match.groupdict()['stage_number'])
map_progress = int(match.groupdict()['map_progress'])
reduce_progress = int(match.groupdict()['reduce_progress'])
stage_number = int(match.groupdict()["stage_number"])
map_progress = int(match.groupdict()["map_progress"])
reduce_progress = int(match.groupdict()["reduce_progress"])
stages[stage_number] = (map_progress + reduce_progress) / 2
logging.info(
'Progress detail: {}, '
'current job {}, '
'total jobs: {}'.format(stages, current_job, total_jobs))
stage_progress = sum(
stages.values()) / len(stages.values()) if stages else 0
progress = (
100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
"Progress detail: {}, "
"current job {}, "
"total jobs: {}".format(stages, current_job, total_jobs)
)
stage_progress = sum(stages.values()) / len(stages.values()) if stages else 0
progress = 100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
return int(progress)
@classmethod
def get_tracking_url(cls, log_lines):
lkp = 'Tracking URL = '
lkp = "Tracking URL = "
for line in log_lines:
if lkp in line:
return line.split(lkp)[1]
@ -235,6 +243,7 @@ class HiveEngineSpec(PrestoEngineSpec):
def handle_cursor(cls, cursor, query, session):
"""Updates progress information"""
from pyhive import hive # pylint: disable=no-name-in-module
unfinished_states = (
hive.ttypes.TOperationState.INITIALIZED_STATE,
hive.ttypes.TOperationState.RUNNING_STATE,
@ -249,11 +258,11 @@ class HiveEngineSpec(PrestoEngineSpec):
cursor.cancel()
break
log = cursor.fetch_logs() or ''
log = cursor.fetch_logs() or ""
if log:
log_lines = log.splitlines()
progress = cls.progress(log_lines)
logging.info('Progress total: {}'.format(progress))
logging.info("Progress total: {}".format(progress))
needs_commit = False
if progress > query.progress:
query.progress = progress
@ -261,21 +270,19 @@ class HiveEngineSpec(PrestoEngineSpec):
if not tracking_url:
tracking_url = cls.get_tracking_url(log_lines)
if tracking_url:
job_id = tracking_url.split('/')[-2]
logging.info(
'Found the tracking url: {}'.format(tracking_url))
job_id = tracking_url.split("/")[-2]
logging.info("Found the tracking url: {}".format(tracking_url))
tracking_url = tracking_url_trans(tracking_url)
logging.info(
'Transformation applied: {}'.format(tracking_url))
logging.info("Transformation applied: {}".format(tracking_url))
query.tracking_url = tracking_url
logging.info('Job id: {}'.format(job_id))
logging.info("Job id: {}".format(job_id))
needs_commit = True
if job_id and len(log_lines) > last_log_line:
# Wait for job id before logging things out
# this allows for prefixing all log lines and becoming
# searchable in something like Kibana
for l in log_lines[last_log_line:]:
logging.info('[{}] {}'.format(job_id, l))
logging.info("[{}] {}".format(job_id, l))
last_log_line = len(log_lines)
if needs_commit:
session.commit()
@ -284,21 +291,22 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str) -> List[dict]:
cls, inspector: Inspector, table_name: str, schema: str
) -> List[dict]:
return inspector.get_columns(table_name, schema)
@classmethod
def where_latest_partition(
cls, table_name, schema, database, qry, columns=None):
def where_latest_partition(cls, table_name, schema, database, qry, columns=None):
try:
col_name, value = cls.latest_partition(
table_name, schema, database, show_first=True)
table_name, schema, database, show_first=True
)
except Exception:
# table is not partitioned
return False
if value is not None:
for c in columns:
if c.get('name') == col_name:
if c.get("name") == col_name:
return qry.where(Column(col_name) == value)
return False
@ -315,20 +323,36 @@ class HiveEngineSpec(PrestoEngineSpec):
def _latest_partition_from_df(cls, df):
"""Hive partitions look like ds={partition name}"""
if not df.empty:
return df.ix[:, 0].max().split('=')[1]
return df.ix[:, 0].max().split("=")[1]
@classmethod
def _partition_query(
cls, table_name, limit=0, order_by=None, filters=None):
return f'SHOW PARTITIONS {table_name}'
def _partition_query(cls, table_name, limit=0, order_by=None, filters=None):
return f"SHOW PARTITIONS {table_name}"
@classmethod
def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None,
limit: int = 100, show_cols: bool = False, indent: bool = True,
latest_partition: bool = True, cols: List[dict] = []) -> str:
def select_star(
cls,
my_db,
table_name: str,
engine: Engine,
schema: str = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: List[dict] = [],
) -> str:
return BaseEngineSpec.select_star(
my_db, table_name, engine, schema, limit,
show_cols, indent, latest_partition, cols)
my_db,
table_name,
engine,
schema,
limit,
show_cols,
indent,
latest_partition,
cols,
)
@classmethod
def modify_url_for_impersonation(cls, url, impersonate_user, username):
@ -356,13 +380,18 @@ class HiveEngineSpec(PrestoEngineSpec):
url = make_url(uri)
backend_name = url.get_backend_name()
# Must be Hive connection, enable impersonation, and set param auth=LDAP|KERBEROS
if (backend_name == 'hive' and 'auth' in url.query.keys() and
impersonate_user is True and username is not None):
configuration['hive.server2.proxy.user'] = username
# Must be Hive connection, enable impersonation, and set param
# auth=LDAP|KERBEROS
if (
backend_name == "hive"
and "auth" in url.query.keys()
and impersonate_user is True
and username is not None
):
configuration["hive.server2.proxy.user"] = username
return configuration
@staticmethod
def execute(cursor, query, async_=False):
kwargs = {'async': async_}
kwargs = {"async": async_}
cursor.execute(query, **kwargs)
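
Most of handle_cursor above reduces to scraping Hive's log stream with the regexes defined at the top of the class and folding the job/stage counters into one progress percentage. A self-contained sketch of that scrape on two sample lines, whose format follows the comments quoted in the class body; the numbers are invented:

import re

launching_job_r = re.compile(
    ".*INFO.*Launching Job (?P<job_number>[0-9]+) out of (?P<max_jobs>[0-9]+)"
)
stage_progress_r = re.compile(
    r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
    r"map = (?P<map_progress>[0-9]+)%.*"
    r"reduce = (?P<reduce_progress>[0-9]+)%.*"
)

log_lines = [
    "17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5",
    "17/02/07 19:37:30 INFO exec.Task: Stage-18 map = 100%, reduce = 40%",
]

current_job, total_jobs, stages = 1, 1, {}
for line in log_lines:
    match = launching_job_r.match(line)
    if match:
        current_job = int(match.group("job_number"))
        total_jobs = int(match.group("max_jobs")) or 1
    match = stage_progress_r.match(line)
    if match:
        stages[int(match.group("stage_number"))] = (
            int(match.group("map_progress")) + int(match.group("reduce_progress"))
        ) / 2

# 100 * (jobs already finished) / total, plus the current stage's share of one job.
stage_progress = sum(stages.values()) / len(stages) if stages else 0
print(int(100 * (current_job - 1) / total_jobs + stage_progress / total_jobs))  # 34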


@ -21,32 +21,35 @@ from superset.db_engine_specs.base import BaseEngineSpec
class ImpalaEngineSpec(BaseEngineSpec):
"""Engine spec for Cloudera's Impala"""
engine = 'impala'
engine = "impala"
time_grain_functions = {
None: '{col}',
'PT1M': "TRUNC({col}, 'MI')",
'PT1H': "TRUNC({col}, 'HH')",
'P1D': "TRUNC({col}, 'DD')",
'P1W': "TRUNC({col}, 'WW')",
'P1M': "TRUNC({col}, 'MONTH')",
'P0.25Y': "TRUNC({col}, 'Q')",
'P1Y': "TRUNC({col}, 'YYYY')",
None: "{col}",
"PT1M": "TRUNC({col}, 'MI')",
"PT1H": "TRUNC({col}, 'HH')",
"P1D": "TRUNC({col}, 'DD')",
"P1W": "TRUNC({col}, 'WW')",
"P1M": "TRUNC({col}, 'MONTH')",
"P0.25Y": "TRUNC({col}, 'Q')",
"P1Y": "TRUNC({col}, 'YYYY')",
}
@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
return "from_unixtime({col})"
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
return "'{}'".format(dttm.strftime('%Y-%m-%d'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
if tt == "DATE":
return "'{}'".format(dttm.strftime("%Y-%m-%d"))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def get_schema_names(cls, inspector):
schemas = [row[0] for row in inspector.engine.execute('SHOW SCHEMAS')
if not row[0].startswith('_')]
schemas = [
row[0]
for row in inspector.engine.execute("SHOW SCHEMAS")
if not row[0].startswith("_")
]
return schemas
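convert_dttm above emits a different literal shape for DATE targets than for everything else. A standalone illustration under an assumed helper name (impala_dttm_literal), not the Superset call signature:

from datetime import datetime

def impala_dttm_literal(target_type, dttm):
    # DATE targets get a date-only literal; everything else a full timestamp.
    if target_type.upper() == "DATE":
        return "'{}'".format(dttm.strftime("%Y-%m-%d"))
    return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))

# impala_dttm_literal("DATE", datetime(2019, 6, 25))      -> "'2019-06-25'"
# impala_dttm_literal("TIMESTAMP", datetime(2019, 6, 25)) -> "'2019-06-25 00:00:00'"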

View File

@ -21,28 +21,27 @@ from superset.db_engine_specs.base import BaseEngineSpec
class KylinEngineSpec(BaseEngineSpec):
"""Dialect for Apache Kylin"""
engine = 'kylin'
engine = "kylin"
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO SECOND) AS TIMESTAMP)',
'PT1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MINUTE) AS TIMESTAMP)',
'PT1H': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO HOUR) AS TIMESTAMP)',
'P1D': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO DAY) AS DATE)',
'P1W': 'CAST(TIMESTAMPADD(WEEK, WEEK(CAST({col} AS DATE)) - 1, \
FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)',
'P1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MONTH) AS DATE)',
'P0.25Y': 'CAST(TIMESTAMPADD(QUARTER, QUARTER(CAST({col} AS DATE)) - 1, \
FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)',
'P1Y': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO YEAR) AS DATE)',
None: "{col}",
"PT1S": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO SECOND) AS TIMESTAMP)",
"PT1M": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MINUTE) AS TIMESTAMP)",
"PT1H": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO HOUR) AS TIMESTAMP)",
"P1D": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO DAY) AS DATE)",
"P1W": "CAST(TIMESTAMPADD(WEEK, WEEK(CAST({col} AS DATE)) - 1, \
FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)",
"P1M": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MONTH) AS DATE)",
"P0.25Y": "CAST(TIMESTAMPADD(QUARTER, QUARTER(CAST({col} AS DATE)) - 1, \
FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)",
"P1Y": "CAST(FLOOR(CAST({col} AS TIMESTAMP) TO YEAR) AS DATE)",
}
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
if tt == "DATE":
return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
if tt == 'TIMESTAMP':
return "CAST('{}' AS TIMESTAMP)".format(
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
if tt == "TIMESTAMP":
return "CAST('{}' AS TIMESTAMP)".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))

View File

@ -23,25 +23,25 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'
engine = "mssql"
epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
time_grain_functions = {
None: '{col}',
'PT1S': "DATEADD(second, DATEDIFF(second, '2000-01-01', {col}), '2000-01-01')",
'PT1M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}), 0)',
'PT5M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 5 * 5, 0)',
'PT10M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 10 * 10, 0)',
'PT15M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 15 * 15, 0)',
'PT0.5H': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 30 * 30, 0)',
'PT1H': 'DATEADD(hour, DATEDIFF(hour, 0, {col}), 0)',
'P1D': 'DATEADD(day, DATEDIFF(day, 0, {col}), 0)',
'P1W': 'DATEADD(week, DATEDIFF(week, 0, {col}), 0)',
'P1M': 'DATEADD(month, DATEDIFF(month, 0, {col}), 0)',
'P0.25Y': 'DATEADD(quarter, DATEDIFF(quarter, 0, {col}), 0)',
'P1Y': 'DATEADD(year, DATEDIFF(year, 0, {col}), 0)',
None: "{col}",
"PT1S": "DATEADD(second, DATEDIFF(second, '2000-01-01', {col}), '2000-01-01')",
"PT1M": "DATEADD(minute, DATEDIFF(minute, 0, {col}), 0)",
"PT5M": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 5 * 5, 0)",
"PT10M": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 10 * 10, 0)",
"PT15M": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 15 * 15, 0)",
"PT0.5H": "DATEADD(minute, DATEDIFF(minute, 0, {col}) / 30 * 30, 0)",
"PT1H": "DATEADD(hour, DATEDIFF(hour, 0, {col}), 0)",
"P1D": "DATEADD(day, DATEDIFF(day, 0, {col}), 0)",
"P1W": "DATEADD(week, DATEDIFF(week, 0, {col}), 0)",
"P1M": "DATEADD(month, DATEDIFF(month, 0, {col}), 0)",
"P0.25Y": "DATEADD(quarter, DATEDIFF(quarter, 0, {col}), 0)",
"P1Y": "DATEADD(year, DATEDIFF(year, 0, {col}), 0)",
}
@classmethod
@ -51,13 +51,13 @@ class MssqlEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(cls, cursor, limit):
data = super(MssqlEngineSpec, cls).fetch_data(cursor, limit)
if data and type(data[0]).__name__ == 'Row':
if data and type(data[0]).__name__ == "Row":
data = [[elem for elem in r] for r in data]
return data
column_types = [
(String(), re.compile(r'^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)', re.IGNORECASE)),
(UnicodeText(), re.compile(r'^N((VAR){0,1}CHAR|TEXT)', re.IGNORECASE)),
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
(UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
]
@classmethod
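The column_types pairs above map raw MSSQL type strings onto SQLAlchemy types via regexes, separating N-prefixed (Unicode) types from plain character types. A hedged sketch of how such pairs can be applied; map_mssql_type is an illustrative name, not Superset's API:

import re
from sqlalchemy.types import String, UnicodeText

COLUMN_TYPES = [
    (String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
    (UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
]

def map_mssql_type(raw_type):
    # Return the first SQLAlchemy type whose pattern matches the raw string.
    for sqla_type, pattern in COLUMN_TYPES:
        if pattern.match(raw_type):
            return sqla_type
    return None

# map_mssql_type("NVARCHAR(255)") -> UnicodeText(); map_mssql_type("varchar(10)") -> String()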

View File

@ -22,45 +22,42 @@ from superset.db_engine_specs.base import BaseEngineSpec
class MySQLEngineSpec(BaseEngineSpec):
engine = 'mysql'
engine = "mysql"
max_column_name_length = 64
time_grain_functions = {
None: '{col}',
'PT1S': 'DATE_ADD(DATE({col}), '
'INTERVAL (HOUR({col})*60*60 + MINUTE({col})*60'
' + SECOND({col})) SECOND)',
'PT1M': 'DATE_ADD(DATE({col}), '
'INTERVAL (HOUR({col})*60 + MINUTE({col})) MINUTE)',
'PT1H': 'DATE_ADD(DATE({col}), '
'INTERVAL HOUR({col}) HOUR)',
'P1D': 'DATE({col})',
'P1W': 'DATE(DATE_SUB({col}, '
'INTERVAL DAYOFWEEK({col}) - 1 DAY))',
'P1M': 'DATE(DATE_SUB({col}, '
'INTERVAL DAYOFMONTH({col}) - 1 DAY))',
'P0.25Y': 'MAKEDATE(YEAR({col}), 1) '
'+ INTERVAL QUARTER({col}) QUARTER - INTERVAL 1 QUARTER',
'P1Y': 'DATE(DATE_SUB({col}, '
'INTERVAL DAYOFYEAR({col}) - 1 DAY))',
'1969-12-29T00:00:00Z/P1W': 'DATE(DATE_SUB({col}, '
'INTERVAL DAYOFWEEK(DATE_SUB({col}, '
'INTERVAL 1 DAY)) - 1 DAY))',
None: "{col}",
"PT1S": "DATE_ADD(DATE({col}), "
"INTERVAL (HOUR({col})*60*60 + MINUTE({col})*60"
" + SECOND({col})) SECOND)",
"PT1M": "DATE_ADD(DATE({col}), "
"INTERVAL (HOUR({col})*60 + MINUTE({col})) MINUTE)",
"PT1H": "DATE_ADD(DATE({col}), " "INTERVAL HOUR({col}) HOUR)",
"P1D": "DATE({col})",
"P1W": "DATE(DATE_SUB({col}, " "INTERVAL DAYOFWEEK({col}) - 1 DAY))",
"P1M": "DATE(DATE_SUB({col}, " "INTERVAL DAYOFMONTH({col}) - 1 DAY))",
"P0.25Y": "MAKEDATE(YEAR({col}), 1) "
"+ INTERVAL QUARTER({col}) QUARTER - INTERVAL 1 QUARTER",
"P1Y": "DATE(DATE_SUB({col}, " "INTERVAL DAYOFYEAR({col}) - 1 DAY))",
"1969-12-29T00:00:00Z/P1W": "DATE(DATE_SUB({col}, "
"INTERVAL DAYOFWEEK(DATE_SUB({col}, "
"INTERVAL 1 DAY)) - 1 DAY))",
}
type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod
def convert_dttm(cls, target_type, dttm):
if target_type.upper() in ('DATETIME', 'DATE'):
if target_type.upper() in ("DATETIME", "DATE"):
return "STR_TO_DATE('{}', '%Y-%m-%d %H:%i:%s')".format(
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
dttm.strftime("%Y-%m-%d %H:%M:%S")
)
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
if selected_schema:
uri.database = parse.quote(selected_schema, safe='')
uri.database = parse.quote(selected_schema, safe="")
return uri
@classmethod
@ -68,11 +65,10 @@ class MySQLEngineSpec(BaseEngineSpec):
if not cls.type_code_map:
# only import and store if needed at least once
import MySQLdb # pylint: disable=import-error
ft = MySQLdb.constants.FIELD_TYPE
cls.type_code_map = {
getattr(ft, k): k
for k in dir(ft)
if not k.startswith('_')
getattr(ft, k): k for k in dir(ft) if not k.startswith("_")
}
datatype = type_code
if isinstance(type_code, int):
@ -82,7 +78,7 @@ class MySQLEngineSpec(BaseEngineSpec):
@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
return "from_unixtime({col})"
@classmethod
def extract_error_message(cls, e):
@ -101,7 +97,7 @@ class MySQLEngineSpec(BaseEngineSpec):
# MySQL dialect started returning long overflowing datatype
# as in 'VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI'
# and we don't need the verbose collation type
str_cutoff = ' COLLATE '
str_cutoff = " COLLATE "
if str_cutoff in datatype:
datatype = datatype.split(str_cutoff)[0]
return datatype
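The COLLATE handling above trims the verbose collation suffix off MySQL datatypes. A minimal standalone version under an assumed name, for illustration only:

def strip_collation(datatype):
    # Keep only the part of the type string before " COLLATE ".
    str_cutoff = " COLLATE "
    if str_cutoff in datatype:
        datatype = datatype.split(str_cutoff)[0]
    return datatype

# strip_collation("VARCHAR(255) COLLATE UTF8MB4_GENERAL_CI") -> "VARCHAR(255)"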

View File

@ -20,25 +20,25 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
engine = "oracle"
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
max_column_name_length = 30
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as DATE)',
'PT1M': "TRUNC(CAST({col} as DATE), 'MI')",
'PT1H': "TRUNC(CAST({col} as DATE), 'HH')",
'P1D': "TRUNC(CAST({col} as DATE), 'DDD')",
'P1W': "TRUNC(CAST({col} as DATE), 'WW')",
'P1M': "TRUNC(CAST({col} as DATE), 'MONTH')",
'P0.25Y': "TRUNC(CAST({col} as DATE), 'Q')",
'P1Y': "TRUNC(CAST({col} as DATE), 'YEAR')",
None: "{col}",
"PT1S": "CAST({col} as DATE)",
"PT1M": "TRUNC(CAST({col} as DATE), 'MI')",
"PT1H": "TRUNC(CAST({col} as DATE), 'HH')",
"P1D": "TRUNC(CAST({col} as DATE), 'DDD')",
"P1W": "TRUNC(CAST({col} as DATE), 'WW')",
"P1M": "TRUNC(CAST({col} as DATE), 'MONTH')",
"P0.25Y": "TRUNC(CAST({col} as DATE), 'Q')",
"P1Y": "TRUNC(CAST({col} as DATE), 'YEAR')",
}
@classmethod
def convert_dttm(cls, target_type, dttm):
return (
"""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
).format(dttm.isoformat())
return ("""TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""").format(
dttm.isoformat()
)

View File

@ -23,37 +23,38 @@ from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
class PinotEngineSpec(BaseEngineSpec):
engine = 'pinot'
engine = "pinot"
allows_subquery = False
inner_joins = False
supports_column_aliases = False
# Pinot does its own conversion below
time_grain_functions: Dict[Optional[str], str] = {
'PT1S': '1:SECONDS',
'PT1M': '1:MINUTES',
'PT1H': '1:HOURS',
'P1D': '1:DAYS',
'P1W': '1:WEEKS',
'P1M': '1:MONTHS',
'P0.25Y': '3:MONTHS',
'P1Y': '1:YEARS',
"PT1S": "1:SECONDS",
"PT1M": "1:MINUTES",
"PT1H": "1:HOURS",
"P1D": "1:DAYS",
"P1W": "1:WEEKS",
"P1M": "1:MONTHS",
"P0.25Y": "3:MONTHS",
"P1Y": "1:YEARS",
}
@classmethod
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
time_grain: Optional[str]) -> TimestampExpression:
is_epoch = pdf in ('epoch_s', 'epoch_ms')
def get_timestamp_expr(
cls, col: ColumnClause, pdf: Optional[str], time_grain: Optional[str]
) -> TimestampExpression:
is_epoch = pdf in ("epoch_s", "epoch_ms")
if not is_epoch:
raise NotImplementedError('Pinot currently only supports epochs')
raise NotImplementedError("Pinot currently only supports epochs")
        # The DATETIMECONVERT Pinot UDF is documented at
        # https://github.com/apache/incubator-pinot/wiki/dateTimeConvert-UDF
# We are not really converting any time units, just bucketing them.
seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS'
tf = f'1:{seconds_or_ms}:EPOCH'
seconds_or_ms = "MILLISECONDS" if pdf == "epoch_ms" else "SECONDS"
tf = f"1:{seconds_or_ms}:EPOCH"
granularity = cls.time_grain_functions.get(time_grain)
if not granularity:
raise NotImplementedError('No pinot grain spec for ' + str(time_grain))
raise NotImplementedError("No pinot grain spec for " + str(time_grain))
# In pinot the output is a string since there is no timestamp column like pg
time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
return TimestampExpression(time_expr, col)
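The expression built above buckets epoch columns with Pinot's DATETIMECONVERT UDF. A standalone sketch of the same string assembly; the TIME_GRAINS subset and the pinot_time_expr name are assumptions for illustration:

TIME_GRAINS = {"PT1S": "1:SECONDS", "PT1M": "1:MINUTES", "PT1H": "1:HOURS", "P1D": "1:DAYS"}

def pinot_time_expr(col, pdf, time_grain):
    # Only epoch columns are supported; pick the matching epoch unit.
    if pdf not in ("epoch_s", "epoch_ms"):
        raise NotImplementedError("Pinot currently only supports epochs")
    seconds_or_ms = "MILLISECONDS" if pdf == "epoch_ms" else "SECONDS"
    tf = f"1:{seconds_or_ms}:EPOCH"
    granularity = TIME_GRAINS.get(time_grain)
    if not granularity:
        raise NotImplementedError("No pinot grain spec for " + str(time_grain))
    return f'DATETIMECONVERT({col}, "{tf}", "{tf}", "{granularity}")'

# pinot_time_expr("ts", "epoch_ms", "P1D")
# -> 'DATETIMECONVERT(ts, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS:EPOCH", "1:DAYS")'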

View File

@ -21,18 +21,18 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
engine = ''
engine = ""
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('second', {col})",
'PT1M': "DATE_TRUNC('minute', {col})",
'PT1H': "DATE_TRUNC('hour', {col})",
'P1D': "DATE_TRUNC('day', {col})",
'P1W': "DATE_TRUNC('week', {col})",
'P1M': "DATE_TRUNC('month', {col})",
'P0.25Y': "DATE_TRUNC('quarter', {col})",
'P1Y': "DATE_TRUNC('year', {col})",
None: "{col}",
"PT1S": "DATE_TRUNC('second', {col})",
"PT1M": "DATE_TRUNC('minute', {col})",
"PT1H": "DATE_TRUNC('hour', {col})",
"P1D": "DATE_TRUNC('day', {col})",
"P1W": "DATE_TRUNC('week', {col})",
"P1M": "DATE_TRUNC('month', {col})",
"P0.25Y": "DATE_TRUNC('quarter', {col})",
"P1Y": "DATE_TRUNC('year', {col})",
}
@classmethod
@ -49,11 +49,11 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
class PostgresEngineSpec(PostgresBaseEngineSpec):
engine = 'postgresql'
engine = "postgresql"
max_column_name_length = 63
try_remove_schema_from_table_name = False
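Each time_grain_functions value is a template with a {col} placeholder that Superset fills with the time column (via a TimestampExpression). Formatting one by hand just shows the SQL that results; this is an illustration of the template, not the actual substitution mechanism:

# {col} behaves like a plain placeholder in the templates above.
grain_expr = "DATE_TRUNC('week', {col})"
print(grain_expr.format(col="created_on"))  # DATE_TRUNC('week', created_on)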

View File

@ -38,24 +38,22 @@ QueryStatus = utils.QueryStatus
class PrestoEngineSpec(BaseEngineSpec):
engine = 'presto'
engine = "presto"
time_grain_functions = {
None: '{col}',
'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))",
'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))",
'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))",
'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))",
'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))",
'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))",
'P1W/1970-01-03T00:00:00Z':
"date_add('day', 5, date_trunc('week', date_add('day', 1, "
'CAST({col} AS TIMESTAMP))))',
'1969-12-28T00:00:00Z/P1W':
"date_add('day', -1, date_trunc('week', "
"date_add('day', 1, CAST({col} AS TIMESTAMP))))",
None: "{col}",
"PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))",
"PT1M": "date_trunc('minute', CAST({col} AS TIMESTAMP))",
"PT1H": "date_trunc('hour', CAST({col} AS TIMESTAMP))",
"P1D": "date_trunc('day', CAST({col} AS TIMESTAMP))",
"P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))",
"P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))",
"P0.25Y": "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
"P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))",
"P1W/1970-01-03T00:00:00Z": "date_add('day', 5, date_trunc('week', "
"date_add('day', 1, CAST({col} AS TIMESTAMP))))",
"1969-12-28T00:00:00Z/P1W": "date_add('day', -1, date_trunc('week', "
"date_add('day', 1, CAST({col} AS TIMESTAMP))))",
}
@classmethod
@ -76,10 +74,7 @@ class PrestoEngineSpec(BaseEngineSpec):
:param data_type: column data type
:return: column info object
"""
return {
'name': name,
'type': f'{data_type}',
}
return {"name": name, "type": f"{data_type}"}
@classmethod
def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
@ -88,7 +83,7 @@ class PrestoEngineSpec(BaseEngineSpec):
:param names: list of all individual column names
:return: full column name
"""
return '.'.join(column[0] for column in names if column[0])
return ".".join(column[0] for column in names if column[0])
@classmethod
def _has_nested_data_types(cls, component_type: str) -> bool:
@ -98,10 +93,12 @@ class PrestoEngineSpec(BaseEngineSpec):
:param component_type: data type
:return: boolean
"""
comma_regex = r',(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
white_space_regex = r'\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
return re.search(comma_regex, component_type) is not None \
comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
return (
re.search(comma_regex, component_type) is not None
or re.search(white_space_regex, component_type) is not None
)
@classmethod
def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
@ -114,38 +111,38 @@ class PrestoEngineSpec(BaseEngineSpec):
:return: list of strings after breaking it by the delimiter
"""
return re.split(
r'{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'.format(delimiter), data_type)
r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
)
@classmethod
def _parse_structural_column(cls,
parent_column_name: str,
parent_data_type: str,
result: List[dict]) -> None:
def _parse_structural_column(
cls, parent_column_name: str, parent_data_type: str, result: List[dict]
) -> None:
"""
Parse a row or array column
:param result: list tracking the results
"""
formatted_parent_column_name = parent_column_name
# Quote the column name if there is a space
if ' ' in parent_column_name:
if " " in parent_column_name:
formatted_parent_column_name = f'"{parent_column_name}"'
full_data_type = f'{formatted_parent_column_name} {parent_data_type}'
full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
original_result_len = len(result)
# split on open parenthesis ( to get the structural
# data type and its component types
data_types = cls._split_data_type(full_data_type, r'\(')
data_types = cls._split_data_type(full_data_type, r"\(")
stack: List[Tuple[str, str]] = []
for data_type in data_types:
# split on closed parenthesis ) to track which component
# types belong to what structural data type
inner_types = cls._split_data_type(data_type, r'\)')
inner_types = cls._split_data_type(data_type, r"\)")
for inner_type in inner_types:
# We have finished parsing multiple structural data types
if not inner_type and len(stack) > 0:
stack.pop()
elif cls._has_nested_data_types(inner_type):
# split on comma , to get individual data types
single_fields = cls._split_data_type(inner_type, ',')
single_fields = cls._split_data_type(inner_type, ",")
for single_field in single_fields:
single_field = single_field.strip()
# If component type starts with a comma, the first single field
@ -153,30 +150,37 @@ class PrestoEngineSpec(BaseEngineSpec):
if not single_field:
continue
# split on whitespace to get field name and data type
field_info = cls._split_data_type(single_field, r'\s')
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
if field_info[1] == 'array' or field_info[1] == 'row':
if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack)
result.append(cls._create_column_info(
full_parent_path,
presto_type_map[field_info[1]]()))
result.append(
cls._create_column_info(
full_parent_path, presto_type_map[field_info[1]]()
)
)
else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack)
column_name = '{}.{}'.format(full_parent_path, field_info[0])
result.append(cls._create_column_info(
column_name, presto_type_map[field_info[1]]()))
column_name = "{}.{}".format(
full_parent_path, field_info[0]
)
result.append(
cls._create_column_info(
column_name, presto_type_map[field_info[1]]()
)
)
# If the component type ends with a structural data type, do not pop
# the stack. We have run across a structural data type within the
# overall structural data type. Otherwise, we have completely parsed
# through the entire structural data type and can move on.
if not (inner_type.endswith('array') or inner_type.endswith('row')):
if not (inner_type.endswith("array") or inner_type.endswith("row")):
stack.pop()
# We have an array of row objects (i.e. array(row(...)))
elif 'array' == inner_type or 'row' == inner_type:
elif "array" == inner_type or "row" == inner_type:
# Push a dummy object to represent the structural data type
stack.append(('', inner_type))
stack.append(("", inner_type))
                # We have an array of a basic data type (i.e. array(varchar)).
elif len(stack) > 0:
# Because it is an array of a basic data type. We have finished
@ -185,12 +189,14 @@ class PrestoEngineSpec(BaseEngineSpec):
# Unquote the column name if necessary
if formatted_parent_column_name != parent_column_name:
for index in range(original_result_len, len(result)):
result[index]['name'] = result[index]['name'].replace(
formatted_parent_column_name, parent_column_name)
result[index]["name"] = result[index]["name"].replace(
formatted_parent_column_name, parent_column_name
)
@classmethod
def _show_columns(
cls, inspector: Inspector, table_name: str, schema: str) -> List[RowProxy]:
cls, inspector: Inspector, table_name: str, schema: str
) -> List[RowProxy]:
"""
Show presto column names
:param inspector: object that performs database schema inspection
@ -201,13 +207,14 @@ class PrestoEngineSpec(BaseEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
full_table = '{}.{}'.format(quote(schema), full_table)
columns = inspector.bind.execute('SHOW COLUMNS FROM {}'.format(full_table))
full_table = "{}.{}".format(quote(schema), full_table)
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
return columns
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str) -> List[dict]:
cls, inspector: Inspector, table_name: str, schema: str
) -> List[dict]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
@ -222,22 +229,26 @@ class PrestoEngineSpec(BaseEngineSpec):
for column in columns:
try:
# parse column if it is a row or array
if 'array' in column.Type or 'row' in column.Type:
if "array" in column.Type or "row" in column.Type:
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]['nullable'] = getattr(
column, 'Null', True)
result[structural_column_index]['default'] = None
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue
else: # otherwise column is a basic data type
column_type = presto_type_map[column.Type]()
except KeyError:
logging.info('Did not recognize type {} of column {}'.format(
column.Type, column.Column))
logging.info(
"Did not recognize type {} of column {}".format(
column.Type, column.Column
)
)
column_type = types.NullType
column_info = cls._create_column_info(column.Column, column_type)
column_info['nullable'] = getattr(column, 'Null', True)
column_info['default'] = None
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
result.append(column_info)
return result
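The column parsing above repeatedly splits Presto type strings on delimiters that sit outside double quotes. A standalone illustration of that quote-aware split, using the same regex idea as _split_data_type under a simplified name:

import re

def split_outside_quotes(data_type, delimiter):
    # Split on the delimiter only when an even number of quotes follows it,
    # i.e. when the delimiter is not inside a quoted identifier.
    return re.split(r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type)

# The comma inside the quoted field name is preserved:
# split_outside_quotes('row("a,b" int, c varchar)', ",")
# -> ['row("a,b" int', ' c varchar)']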
@ -258,9 +269,9 @@ class PrestoEngineSpec(BaseEngineSpec):
:return: column clauses
"""
column_clauses = []
# Column names are separated by periods. This regex will find periods in a string
# if they are not enclosed in quotes because if a period is enclosed in quotes,
# then that period is part of a column name.
# Column names are separated by periods. This regex will find periods in a
# string if they are not enclosed in quotes because if a period is enclosed in
# quotes, then that period is part of a column name.
dot_pattern = r"""\. # split on period
(?= # look ahead
(?: # create non-capture group
@ -269,26 +280,28 @@ class PrestoEngineSpec(BaseEngineSpec):
dot_regex = re.compile(dot_pattern, re.VERBOSE)
for col in cols:
# get individual column names
col_names = re.split(dot_regex, col['name'])
col_names = re.split(dot_regex, col["name"])
# quote each column name if it is not already quoted
for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name):
col_names[index] = '"{}"'.format(col_name)
quoted_col_name = '.'.join(
quoted_col_name = ".".join(
col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
for col_name in col_names)
for col_name in col_names
)
# create column clause in the format "name"."name" AS "name.name"
column_clause = literal_column(quoted_col_name).label(col['name'])
column_clause = literal_column(quoted_col_name).label(col["name"])
column_clauses.append(column_clause)
return column_clauses
@classmethod
def _filter_out_array_nested_cols(
cls, cols: List[dict]) -> Tuple[List[dict], List[dict]]:
cls, cols: List[dict]
) -> Tuple[List[dict], List[dict]]:
"""
Filter out columns that correspond to array content. We know which columns to
skip because cols is a list provided to us in a specific order where a structural
column is positioned right before its content.
skip because cols is a list provided to us in a specific order where a
structural column is positioned right before its content.
Example: Column Name: ColA, Column Data Type: array(row(nest_obj int))
cols = [ ..., ColA, ColA.nest_obj, ... ]
@ -296,20 +309,21 @@ class PrestoEngineSpec(BaseEngineSpec):
When we run across an array, check if subsequent column names start with the
array name and skip them.
:param cols: columns
:return: filtered list of columns and list of array columns and its nested fields
:return: filtered list of columns and list of array columns and its nested
fields
"""
filtered_cols = []
array_cols = []
curr_array_col_name = None
for col in cols:
# col corresponds to an array's content and should be skipped
if curr_array_col_name and col['name'].startswith(curr_array_col_name):
if curr_array_col_name and col["name"].startswith(curr_array_col_name):
array_cols.append(col)
continue
# col is an array so we need to check if subsequent
# columns correspond to the array's contents
elif str(col['type']) == 'ARRAY':
curr_array_col_name = col['name']
elif str(col["type"]) == "ARRAY":
curr_array_col_name = col["name"]
array_cols.append(col)
filtered_cols.append(col)
else:
@ -318,9 +332,18 @@ class PrestoEngineSpec(BaseEngineSpec):
return filtered_cols, array_cols
@classmethod
def select_star(cls, my_db, table_name: str, engine: Engine, schema: str = None,
limit: int = 100, show_cols: bool = False, indent: bool = True,
latest_partition: bool = True, cols: List[dict] = []) -> str:
def select_star(
cls,
my_db,
table_name: str,
engine: Engine,
schema: str = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: List[dict] = [],
) -> str:
"""
Include selecting properties of row objects. We cannot easily break arrays into
rows, so render the whole array in its own row and skip columns that correspond
@ -328,59 +351,71 @@ class PrestoEngineSpec(BaseEngineSpec):
"""
presto_cols = cols
if show_cols:
dot_regex = r'\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)'
dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
presto_cols = [
col for col in presto_cols if not re.search(dot_regex, col['name'])]
col for col in presto_cols if not re.search(dot_regex, col["name"])
]
return super(PrestoEngineSpec, cls).select_star(
my_db, table_name, engine, schema, limit,
show_cols, indent, latest_partition, presto_cols,
my_db,
table_name,
engine,
schema,
limit,
show_cols,
indent,
latest_partition,
presto_cols,
)
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe='')
if '/' in database:
database = database.split('/')[0] + '/' + selected_schema
selected_schema = parse.quote(selected_schema, safe="")
if "/" in database:
database = database.split("/")[0] + "/" + selected_schema
else:
database += '/' + selected_schema
database += "/" + selected_schema
uri.database = database
return uri
@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
if tt == 'DATE':
if tt == "DATE":
return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
if tt == 'TIMESTAMP':
if tt == "TIMESTAMP":
return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime("%Y-%m-%d %H:%M:%S"))
@classmethod
def epoch_to_dttm(cls):
return 'from_unixtime({col})'
return "from_unixtime({col})"
@classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
def get_all_datasource_names(
cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
datasource_df = db.get_df(
'SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S '
"SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S "
"ORDER BY concat(table_schema, '.', table_name)".format(
datasource_type.upper(),
datasource_type.upper()
),
None)
None,
)
datasource_names: List[utils.DatasourceName] = []
for unused, row in datasource_df.iterrows():
datasource_names.append(utils.DatasourceName(
schema=row['table_schema'], table=row['table_name']))
datasource_names.append(
utils.DatasourceName(
schema=row["table_schema"], table=row["table_name"]
)
)
return datasource_names
@classmethod
def _build_column_hierarchy(cls,
columns: List[dict],
parent_column_types: List[str],
column_hierarchy: dict) -> None:
def _build_column_hierarchy(
cls, columns: List[dict], parent_column_types: List[str], column_hierarchy: dict
) -> None:
"""
Build a graph where the root node represents a column whose data type is in
parent_column_types. A node's children represent that column's nested fields
@ -392,28 +427,30 @@ class PrestoEngineSpec(BaseEngineSpec):
if len(columns) == 0:
return
root = columns.pop(0)
root_info = {'type': root['type'], 'children': []}
column_hierarchy[root['name']] = root_info
root_info = {"type": root["type"], "children": []}
column_hierarchy[root["name"]] = root_info
while columns:
column = columns[0]
# If the column name does not start with the root's name,
# then this column is not a nested field
if not column['name'].startswith(f"{root['name']}."):
if not column["name"].startswith(f"{root['name']}."):
break
# If the column's data type is one of the parent types,
# then this column may have nested fields
if str(column['type']) in parent_column_types:
cls._build_column_hierarchy(columns, parent_column_types,
column_hierarchy)
root_info['children'].append(column['name'])
if str(column["type"]) in parent_column_types:
cls._build_column_hierarchy(
columns, parent_column_types, column_hierarchy
)
root_info["children"].append(column["name"])
continue
else: # The column is a nested field
root_info['children'].append(column['name'])
root_info["children"].append(column["name"])
columns.pop(0)
@classmethod
def _create_row_and_array_hierarchy(
cls, selected_columns: List[dict]) -> Tuple[dict, dict, List[dict]]:
cls, selected_columns: List[dict]
) -> Tuple[dict, dict, List[dict]]:
"""
Build graphs where the root node represents a row or array and its children
are that column's nested fields
@ -425,29 +462,30 @@ class PrestoEngineSpec(BaseEngineSpec):
array_column_hierarchy: OrderedDict = OrderedDict()
expanded_columns: List[dict] = []
for column in selected_columns:
if column['type'].startswith('ROW'):
if column["type"].startswith("ROW"):
parsed_row_columns: List[dict] = []
cls._parse_structural_column(column['name'],
column['type'].lower(),
parsed_row_columns)
cls._parse_structural_column(
column["name"], column["type"].lower(), parsed_row_columns
)
expanded_columns = expanded_columns + parsed_row_columns[1:]
filtered_row_columns, array_columns = cls._filter_out_array_nested_cols(
parsed_row_columns)
cls._build_column_hierarchy(filtered_row_columns,
['ROW'],
row_column_hierarchy)
cls._build_column_hierarchy(array_columns,
['ROW', 'ARRAY'],
array_column_hierarchy)
elif column['type'].startswith('ARRAY'):
parsed_row_columns
)
cls._build_column_hierarchy(
filtered_row_columns, ["ROW"], row_column_hierarchy
)
cls._build_column_hierarchy(
array_columns, ["ROW", "ARRAY"], array_column_hierarchy
)
elif column["type"].startswith("ARRAY"):
parsed_array_columns: List[dict] = []
cls._parse_structural_column(column['name'],
column['type'].lower(),
parsed_array_columns)
cls._parse_structural_column(
column["name"], column["type"].lower(), parsed_array_columns
)
expanded_columns = expanded_columns + parsed_array_columns[1:]
cls._build_column_hierarchy(parsed_array_columns,
['ROW', 'ARRAY'],
array_column_hierarchy)
cls._build_column_hierarchy(
parsed_array_columns, ["ROW", "ARRAY"], array_column_hierarchy
)
return row_column_hierarchy, array_column_hierarchy, expanded_columns
@classmethod
@ -457,7 +495,7 @@ class PrestoEngineSpec(BaseEngineSpec):
:param columns: list of columns
:return: dictionary representing an empty row of data
"""
return {column['name']: '' for column in columns}
return {column["name"]: "" for column in columns}
@classmethod
def _expand_row_data(cls, datum: dict, column: str, column_hierarchy: dict) -> None:
@ -470,22 +508,23 @@ class PrestoEngineSpec(BaseEngineSpec):
"""
if column in datum:
row_data = datum[column]
row_children = column_hierarchy[column]['children']
row_children = column_hierarchy[column]["children"]
if row_data and len(row_data) != len(row_children):
raise Exception('The number of data values and number of nested'
'fields are not equal')
raise Exception(
"The number of data values and number of nested"
"fields are not equal"
)
elif row_data:
for index, data_value in enumerate(row_data):
datum[row_children[index]] = data_value
else:
for row_child in row_children:
datum[row_child] = ''
datum[row_child] = ""
@classmethod
def _split_array_columns_by_process_state(
cls, array_columns: List[str],
array_column_hierarchy: dict,
datum: dict) -> Tuple[List[str], Set[str]]:
cls, array_columns: List[str], array_column_hierarchy: dict, datum: dict
) -> Tuple[List[str], Set[str]]:
"""
Take a list of array columns and split them according to whether or not we are
ready to process them from a data set
@ -501,7 +540,7 @@ class PrestoEngineSpec(BaseEngineSpec):
for array_column in array_columns:
if array_column in datum:
array_columns_to_process.append(array_column)
elif str(array_column_hierarchy[array_column]['type']) == 'ARRAY':
elif str(array_column_hierarchy[array_column]["type"]) == "ARRAY":
child_array = array_column
unprocessed_array_columns.add(child_array)
elif child_array and array_column.startswith(child_array):
@ -510,7 +549,8 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def _convert_data_list_to_array_data_dict(
cls, data: List[dict], array_columns_to_process: List[str]) -> dict:
cls, data: List[dict], array_columns_to_process: List[str]
) -> dict:
"""
Pull out array data from rows of data into a dictionary where the key represents
the index in the data list and the value is the array data values
@ -536,10 +576,9 @@ class PrestoEngineSpec(BaseEngineSpec):
return array_data_dict
@classmethod
def _process_array_data(cls,
data: List[dict],
all_columns: List[dict],
array_column_hierarchy: dict) -> dict:
def _process_array_data(
cls, data: List[dict], all_columns: List[dict], array_column_hierarchy: dict
) -> dict:
"""
Pull out array data that is ready to be processed into a dictionary.
The key refers to the index in the original data set. The value is
@ -575,38 +614,39 @@ class PrestoEngineSpec(BaseEngineSpec):
# Determine what columns are ready to be processed. This is necessary for
# array columns that contain rows with nested arrays. We first process
# the outer arrays before processing inner arrays.
array_columns_to_process, \
unprocessed_array_columns = cls._split_array_columns_by_process_state(
array_columns, array_column_hierarchy, data[0])
array_columns_to_process, unprocessed_array_columns = cls._split_array_columns_by_process_state(
array_columns, array_column_hierarchy, data[0]
)
# Pull out array data that is ready to be processed into a dictionary.
all_array_data = cls._convert_data_list_to_array_data_dict(
data, array_columns_to_process)
data, array_columns_to_process
)
for original_data_index, expanded_array_data in all_array_data.items():
for array_column in array_columns:
if array_column in unprocessed_array_columns:
continue
# Expand array values that are rows
if str(array_column_hierarchy[array_column]['type']) == 'ROW':
if str(array_column_hierarchy[array_column]["type"]) == "ROW":
for array_value in expanded_array_data:
cls._expand_row_data(array_value,
array_column,
array_column_hierarchy)
cls._expand_row_data(
array_value, array_column, array_column_hierarchy
)
continue
array_data = expanded_array_data[0][array_column]
array_children = array_column_hierarchy[array_column]
# This is an empty array of primitive data type
if not array_data and not array_children['children']:
if not array_data and not array_children["children"]:
continue
# Pull out complex array values into its own row of data
elif array_data and array_children['children']:
elif array_data and array_children["children"]:
for array_index, data_value in enumerate(array_data):
if array_index >= len(expanded_array_data):
empty_data = cls._create_empty_row_of_data(all_columns)
expanded_array_data.append(empty_data)
for index, datum_value in enumerate(data_value):
array_child = array_children['children'][index]
array_child = array_children["children"][index]
expanded_array_data[array_index][array_child] = datum_value
# Pull out primitive array values into its own row of data
elif array_data:
@ -617,15 +657,15 @@ class PrestoEngineSpec(BaseEngineSpec):
expanded_array_data[array_index][array_column] = data_value
# This is an empty array with nested fields
else:
for index, array_child in enumerate(array_children['children']):
for index, array_child in enumerate(array_children["children"]):
for array_value in expanded_array_data:
array_value[array_child] = ''
array_value[array_child] = ""
return all_array_data
@classmethod
def _consolidate_array_data_into_data(cls,
data: List[dict],
array_data: dict) -> None:
def _consolidate_array_data_into_data(
cls, data: List[dict], array_data: dict
) -> None:
"""
Consolidate data given a list representing rows of data and a dictionary
representing expanded array data
@ -659,14 +699,14 @@ class PrestoEngineSpec(BaseEngineSpec):
while data_index < len(data):
data[data_index].update(array_data[original_data_index][0])
array_data[original_data_index].pop(0)
data[data_index + 1:data_index + 1] = array_data[original_data_index]
data[data_index + 1 : data_index + 1] = array_data[original_data_index]
data_index = data_index + len(array_data[original_data_index]) + 1
original_data_index = original_data_index + 1
@classmethod
def _remove_processed_array_columns(cls,
unprocessed_array_columns: Set[str],
array_column_hierarchy: dict) -> None:
def _remove_processed_array_columns(
cls, unprocessed_array_columns: Set[str], array_column_hierarchy: dict
) -> None:
"""
Remove keys representing array columns that have already been processed
:param unprocessed_array_columns: list of unprocessed array columns
@ -680,9 +720,9 @@ class PrestoEngineSpec(BaseEngineSpec):
del array_column_hierarchy[array_column]
@classmethod
def expand_data(cls,
columns: List[dict],
data: List[dict]) -> Tuple[List[dict], List[dict], List[dict]]:
def expand_data(
cls, columns: List[dict], data: List[dict]
) -> Tuple[List[dict], List[dict], List[dict]]:
"""
We do not immediately display rows and arrays clearly in the data grid. This
method separates out nested fields and data values to help clearly display
@ -707,18 +747,18 @@ class PrestoEngineSpec(BaseEngineSpec):
all_columns: List[dict] = []
# Get the list of all columns (selected fields and their nested fields)
for column in columns:
if column['type'].startswith('ARRAY') or column['type'].startswith('ROW'):
cls._parse_structural_column(column['name'],
column['type'].lower(),
all_columns)
if column["type"].startswith("ARRAY") or column["type"].startswith("ROW"):
cls._parse_structural_column(
column["name"], column["type"].lower(), all_columns
)
else:
all_columns.append(column)
# Build graphs where the root node is a row or array and its children are that
# column's nested fields
row_column_hierarchy,\
array_column_hierarchy,\
expanded_columns = cls._create_row_and_array_hierarchy(columns)
row_column_hierarchy, array_column_hierarchy, expanded_columns = cls._create_row_and_array_hierarchy(
columns
)
# Pull out a row's nested fields and their values into separate columns
ordered_row_columns = row_column_hierarchy.keys()
@ -729,17 +769,18 @@ class PrestoEngineSpec(BaseEngineSpec):
while array_column_hierarchy:
array_columns = list(array_column_hierarchy.keys())
# Determine what columns are ready to be processed.
array_columns_to_process,\
unprocessed_array_columns = cls._split_array_columns_by_process_state(
array_columns, array_column_hierarchy, data[0])
all_array_data = cls._process_array_data(data,
all_columns,
array_column_hierarchy)
array_columns_to_process, unprocessed_array_columns = cls._split_array_columns_by_process_state(
array_columns, array_column_hierarchy, data[0]
)
all_array_data = cls._process_array_data(
data, all_columns, array_column_hierarchy
)
# Consolidate the original data set and the expanded array data
cls._consolidate_array_data_into_data(data, all_array_data)
# Remove processed array columns from the graph
cls._remove_processed_array_columns(unprocessed_array_columns,
array_column_hierarchy)
cls._remove_processed_array_columns(
unprocessed_array_columns, array_column_hierarchy
)
return all_columns, data, expanded_columns
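A minimal standalone sketch of the row-expansion step used above (_expand_row_data): each element of a row value is copied into its own "<parent>.<child>" key. The function name and the example row are illustrative assumptions, not Superset's exact output format:

def expand_row(datum, column, children):
    # Copy each row element into its dotted child column, or blank the
    # children out when the row value is empty.
    row_data = datum.get(column)
    if row_data:
        for index, value in enumerate(row_data):
            datum[children[index]] = value
    else:
        for child in children:
            datum[child] = ""

datum = {"user": [1, "alice"]}
expand_row(datum, "user", ["user.id", "user.name"])
# datum -> {"user": [1, "alice"], "user.id": 1, "user.name": "alice"}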
@ -748,25 +789,26 @@ class PrestoEngineSpec(BaseEngineSpec):
indexes = database.get_indexes(table_name, schema_name)
if not indexes:
return {}
cols = indexes[0].get('column_names', [])
cols = indexes[0].get("column_names", [])
full_table_name = table_name
if schema_name and '.' not in table_name:
full_table_name = '{}.{}'.format(schema_name, table_name)
if schema_name and "." not in table_name:
full_table_name = "{}.{}".format(schema_name, table_name)
pql = cls._partition_query(full_table_name)
col_name, latest_part = cls.latest_partition(
table_name, schema_name, database, show_first=True)
table_name, schema_name, database, show_first=True
)
return {
'partitions': {
'cols': cols,
'latest': {col_name: latest_part},
'partitionQuery': pql,
},
"partitions": {
"cols": cols,
"latest": {col_name: latest_part},
"partitionQuery": pql,
}
}
@classmethod
def handle_cursor(cls, cursor, query, session):
"""Updates progress information"""
logging.info('Polling the cursor for progress')
logging.info("Polling the cursor for progress")
polled = cursor.poll()
# poll returns dict -- JSON status information or ``None``
# if the query is done
@ -774,7 +816,7 @@ class PrestoEngineSpec(BaseEngineSpec):
# b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
while polled:
# Update the object and wait for the kill signal.
stats = polled.get('stats', {})
stats = polled.get("stats", {})
query = session.query(type(query)).filter_by(id=query.id).one()
if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
@ -782,50 +824,51 @@ class PrestoEngineSpec(BaseEngineSpec):
break
if stats:
state = stats.get('state')
state = stats.get("state")
# if already finished, then stop polling
if state == 'FINISHED':
if state == "FINISHED":
break
completed_splits = float(stats.get('completedSplits'))
total_splits = float(stats.get('totalSplits'))
completed_splits = float(stats.get("completedSplits"))
total_splits = float(stats.get("totalSplits"))
if total_splits and completed_splits:
progress = 100 * (completed_splits / total_splits)
logging.info(
'Query progress: {} / {} '
'splits'.format(completed_splits, total_splits))
"Query progress: {} / {} "
"splits".format(completed_splits, total_splits)
)
if progress > query.progress:
query.progress = progress
session.commit()
time.sleep(1)
logging.info('Polling the cursor for progress')
logging.info("Polling the cursor for progress")
polled = cursor.poll()
@classmethod
def extract_error_message(cls, e):
if (
hasattr(e, 'orig') and
type(e.orig).__name__ == 'DatabaseError' and
isinstance(e.orig[0], dict)):
hasattr(e, "orig")
and type(e.orig).__name__ == "DatabaseError"
and isinstance(e.orig[0], dict)
):
error_dict = e.orig[0]
return '{} at {}: {}'.format(
error_dict.get('errorName'),
error_dict.get('errorLocation'),
error_dict.get('message'),
return "{} at {}: {}".format(
error_dict.get("errorName"),
error_dict.get("errorLocation"),
error_dict.get("message"),
)
if (
type(e).__name__ == 'DatabaseError' and
hasattr(e, 'args') and
len(e.args) > 0
type(e).__name__ == "DatabaseError"
and hasattr(e, "args")
and len(e.args) > 0
):
error_dict = e.args[0]
return error_dict.get('message')
return error_dict.get("message")
return utils.error_msg_from_exception(e)
@classmethod
def _partition_query(
cls, table_name, limit=0, order_by=None, filters=None):
def _partition_query(cls, table_name, limit=0, order_by=None, filters=None):
"""Returns a partition query
:param table_name: the name of the table to get partitions from
@ -838,42 +881,44 @@ class PrestoEngineSpec(BaseEngineSpec):
:type order_by: list of (str, bool) tuples
:param filters: dict of field name and filter value combinations
"""
limit_clause = 'LIMIT {}'.format(limit) if limit else ''
order_by_clause = ''
limit_clause = "LIMIT {}".format(limit) if limit else ""
order_by_clause = ""
if order_by:
l = [] # noqa: E741
for field, desc in order_by:
l.append(field + ' DESC' if desc else '')
order_by_clause = 'ORDER BY ' + ', '.join(l)
l.append(field + " DESC" if desc else "")
order_by_clause = "ORDER BY " + ", ".join(l)
where_clause = ''
where_clause = ""
if filters:
l = [] # noqa: E741
for field, value in filters.items():
l.append(f"{field} = '{value}'")
where_clause = 'WHERE ' + ' AND '.join(l)
where_clause = "WHERE " + " AND ".join(l)
sql = textwrap.dedent(f"""\
sql = textwrap.dedent(
f"""\
SELECT * FROM "{table_name}$partitions"
{where_clause}
{order_by_clause}
{limit_clause}
""")
"""
)
return sql
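For a table partitioned on a single ds column, a query built from the template above renders roughly as follows; the table name and filter value are illustrative:

SELECT * FROM "logs$partitions"
WHERE ds = '2019-06-25'
ORDER BY ds DESC
LIMIT 1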
@classmethod
def where_latest_partition(
cls, table_name, schema, database, qry, columns=None):
def where_latest_partition(cls, table_name, schema, database, qry, columns=None):
try:
col_name, value = cls.latest_partition(
table_name, schema, database, show_first=True)
table_name, schema, database, show_first=True
)
except Exception:
# table is not partitioned
return False
if value is not None:
for c in columns:
if c.get('name') == col_name:
if c.get("name") == col_name:
return qry.where(Column(col_name) == value)
return False
@ -900,15 +945,17 @@ class PrestoEngineSpec(BaseEngineSpec):
('ds', '2018-01-01')
"""
indexes = database.get_indexes(table_name, schema)
if len(indexes[0]['column_names']) < 1:
if len(indexes[0]["column_names"]) < 1:
raise SupersetTemplateException(
'The table should have one partitioned field')
elif not show_first and len(indexes[0]['column_names']) > 1:
"The table should have one partitioned field"
)
elif not show_first and len(indexes[0]["column_names"]) > 1:
raise SupersetTemplateException(
'The table should have a single partitioned field '
'to use this function. You may want to use '
'`presto.latest_sub_partition`')
part_field = indexes[0]['column_names'][0]
"The table should have a single partitioned field "
"to use this function. You may want to use "
"`presto.latest_sub_partition`"
)
part_field = indexes[0]["column_names"][0]
sql = cls._partition_query(table_name, 1, [(part_field, True)])
df = database.get_df(sql, schema)
return part_field, cls._latest_partition_from_df(df)
@ -941,15 +988,14 @@ class PrestoEngineSpec(BaseEngineSpec):
'2018-01-01'
"""
indexes = database.get_indexes(table_name, schema)
part_fields = indexes[0]['column_names']
part_fields = indexes[0]["column_names"]
for k in kwargs.keys():
            if k not in part_fields:
msg = 'Field [{k}] is not part of the portioning key'
msg = "Field [{k}] is not part of the portioning key"
raise SupersetTemplateException(msg)
if len(kwargs.keys()) != len(part_fields) - 1:
msg = (
'A filter needs to be specified for {} out of the '
'{} fields.'
"A filter needs to be specified for {} out of the " "{} fields."
).format(len(part_fields) - 1, len(part_fields))
raise SupersetTemplateException(msg)
@ -957,9 +1003,8 @@ class PrestoEngineSpec(BaseEngineSpec):
if field not in kwargs.keys():
field_to_return = field
sql = cls._partition_query(
table_name, 1, [(field_to_return, True)], kwargs)
sql = cls._partition_query(table_name, 1, [(field_to_return, True)], kwargs)
df = database.get_df(sql, schema)
if df.empty:
return ''
return ""
return df.to_dict()[field_to_return][0]
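A hedged usage sketch for the two partition helpers above; database stands in for an already-configured Superset Database object pointing at Presto, and the table, schema, and partition values are made up:

from superset.db_engine_specs.presto import PrestoEngineSpec

# Latest partition of a table partitioned by a single "ds" field:
col_name, value = PrestoEngineSpec.latest_partition(
    "logs", "default", database, show_first=True
)  # e.g. ("ds", "2019-06-25")

# For a table partitioned by (ds, hour): fix ds, get the latest hour.
latest_hour = PrestoEngineSpec.latest_sub_partition(
    "events", "default", database, ds="2019-06-25"
)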

View File

@ -19,7 +19,7 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
engine = "redshift"
max_column_name_length = 127
@staticmethod

View File

@ -21,38 +21,38 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
engine = "snowflake"
force_column_alias_quotes = True
max_column_name_length = 256
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
'PT1M': "DATE_TRUNC('MINUTE', {col})",
'PT5M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \
None: "{col}",
"PT1S": "DATE_TRUNC('SECOND', {col})",
"PT1M": "DATE_TRUNC('MINUTE', {col})",
"PT5M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \
DATE_TRUNC('HOUR', {col}))",
'PT10M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \
"PT10M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \
DATE_TRUNC('HOUR', {col}))",
'PT15M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \
"PT15M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \
DATE_TRUNC('HOUR', {col}))",
'PT0.5H': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \
"PT0.5H": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \
DATE_TRUNC('HOUR', {col}))",
'PT1H': "DATE_TRUNC('HOUR', {col})",
'P1D': "DATE_TRUNC('DAY', {col})",
'P1W': "DATE_TRUNC('WEEK', {col})",
'P1M': "DATE_TRUNC('MONTH', {col})",
'P0.25Y': "DATE_TRUNC('QUARTER', {col})",
'P1Y': "DATE_TRUNC('YEAR', {col})",
"PT1H": "DATE_TRUNC('HOUR', {col})",
"P1D": "DATE_TRUNC('DAY', {col})",
"P1W": "DATE_TRUNC('WEEK', {col})",
"P1M": "DATE_TRUNC('MONTH', {col})",
"P0.25Y": "DATE_TRUNC('QUARTER', {col})",
"P1Y": "DATE_TRUNC('YEAR', {col})",
}
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
database = uri.database
if '/' in uri.database:
database = uri.database.split('/')[0]
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe='')
uri.database = database + '/' + selected_schema
selected_schema = parse.quote(selected_schema, safe="")
uri.database = database + "/" + selected_schema
return uri
@classmethod
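adjust_database_uri above keeps only the part of the URI's database field before "/" and appends the selected schema after it. A standalone illustration; adjust is an assumed name, not the Superset signature:

from urllib import parse

def adjust(database_field, selected_schema=None):
    # Without a selected schema the field is left untouched.
    if not selected_schema:
        return database_field
    database = database_field.split("/")[0] if "/" in database_field else database_field
    return database + "/" + parse.quote(selected_schema, safe="")

# adjust("MY_DB/PUBLIC", "ANALYTICS") -> "MY_DB/ANALYTICS"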

View File

@ -22,17 +22,17 @@ from superset.utils import core as utils
class SqliteEngineSpec(BaseEngineSpec):
engine = 'sqlite'
engine = "sqlite"
time_grain_functions = {
None: '{col}',
'PT1H': "DATETIME(STRFTIME('%Y-%m-%dT%H:00:00', {col}))",
'P1D': 'DATE({col})',
'P1W': "DATE({col}, -strftime('%W', {col}) || ' days')",
'P1M': "DATE({col}, -strftime('%d', {col}) || ' days', '+1 day')",
'P1Y': "DATETIME(STRFTIME('%Y-01-01T00:00:00', {col}))",
'P1W/1970-01-03T00:00:00Z': "DATE({col}, 'weekday 6')",
'1969-12-28T00:00:00Z/P1W': "DATE({col}, 'weekday 0', '-7 days')",
None: "{col}",
"PT1H": "DATETIME(STRFTIME('%Y-%m-%dT%H:00:00', {col}))",
"P1D": "DATE({col})",
"P1W": "DATE({col}, -strftime('%W', {col}) || ' days')",
"P1M": "DATE({col}, -strftime('%d', {col}) || ' days', '+1 day')",
"P1Y": "DATETIME(STRFTIME('%Y-01-01T00:00:00', {col}))",
"P1W/1970-01-03T00:00:00Z": "DATE({col}, 'weekday 6')",
"1969-12-28T00:00:00Z/P1W": "DATE({col}, 'weekday 0', '-7 days')",
}
@classmethod
@ -40,30 +40,37 @@ class SqliteEngineSpec(BaseEngineSpec):
return "datetime({col}, 'unixepoch')"
@classmethod
def get_all_datasource_names(cls, db, datasource_type: str) \
-> List[utils.DatasourceName]:
schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True)
def get_all_datasource_names(
cls, db, datasource_type: str
) -> List[utils.DatasourceName]:
schemas = db.get_all_schema_names(
cache=db.schema_cache_enabled,
cache_timeout=db.schema_cache_timeout,
force=True,
)
schema = schemas[0]
if datasource_type == 'table':
if datasource_type == "table":
return db.get_all_table_names_in_schema(
schema=schema, force=True,
schema=schema,
force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
elif datasource_type == 'view':
cache_timeout=db.table_cache_timeout,
)
elif datasource_type == "view":
return db.get_all_view_names_in_schema(
schema=schema, force=True,
schema=schema,
force=True,
cache=db.table_cache_enabled,
cache_timeout=db.table_cache_timeout)
cache_timeout=db.table_cache_timeout,
)
else:
raise Exception(f'Unsupported datasource_type: {datasource_type}')
raise Exception(f"Unsupported datasource_type: {datasource_type}")
@classmethod
def convert_dttm(cls, target_type, dttm):
iso = dttm.isoformat().replace('T', ' ')
if '.' not in iso:
iso += '.000000'
iso = dttm.isoformat().replace("T", " ")
if "." not in iso:
iso += ".000000"
return "'{}'".format(iso)
@classmethod
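convert_dttm above produces an ISO-style literal with the "T" separator replaced by a space and microseconds always present. A standalone illustration under an assumed helper name:

from datetime import datetime

def sqlite_dttm_literal(dttm):
    iso = dttm.isoformat().replace("T", " ")
    # Pad with microseconds when the datetime has none.
    if "." not in iso:
        iso += ".000000"
    return "'{}'".format(iso)

# sqlite_dttm_literal(datetime(2019, 6, 25, 13, 34)) -> "'2019-06-25 13:34:00.000000'"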

View File

@ -20,23 +20,26 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
class TeradataEngineSpec(BaseEngineSpec):
"""Dialect for Teradata DB."""
engine = 'teradata'
engine = "teradata"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 30 # since 14.10 this is 128
time_grain_functions = {
None: '{col}',
'PT1M': "TRUNC(CAST({col} as DATE), 'MI')",
'PT1H': "TRUNC(CAST({col} as DATE), 'HH')",
'P1D': "TRUNC(CAST({col} as DATE), 'DDD')",
'P1W': "TRUNC(CAST({col} as DATE), 'WW')",
'P1M': "TRUNC(CAST({col} as DATE), 'MONTH')",
'P0.25Y': "TRUNC(CAST({col} as DATE), 'Q')",
'P1Y': "TRUNC(CAST({col} as DATE), 'YEAR')",
None: "{col}",
"PT1M": "TRUNC(CAST({col} as DATE), 'MI')",
"PT1H": "TRUNC(CAST({col} as DATE), 'HH')",
"P1D": "TRUNC(CAST({col} as DATE), 'DDD')",
"P1W": "TRUNC(CAST({col} as DATE), 'WW')",
"P1M": "TRUNC(CAST({col} as DATE), 'MONTH')",
"P0.25Y": "TRUNC(CAST({col} as DATE), 'Q')",
"P1Y": "TRUNC(CAST({col} as DATE), 'YEAR')",
}
@classmethod
def epoch_to_dttm(cls):
return "CAST(((CAST(DATE '1970-01-01' + ({col} / 86400) AS TIMESTAMP(0) " \
"AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' " \
'HOUR TO SECOND) AS TIMESTAMP(0))'
return (
"CAST(((CAST(DATE '1970-01-01' + ({col} / 86400) AS TIMESTAMP(0) "
"AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' "
"HOUR TO SECOND) AS TIMESTAMP(0))"
)

View File

@ -19,4 +19,4 @@ from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
class VerticaEngineSpec(PostgresBaseEngineSpec):
engine = 'vertica'
engine = "vertica"

View File

@ -18,8 +18,7 @@
# TODO: contribute back to pyhive.
def fetch_logs(self, max_rows=1024,
orientation=None):
def fetch_logs(self, max_rows=1024, orientation=None):
"""Mocked. Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after
the previous call.
@ -29,18 +28,18 @@ def fetch_logs(self, max_rows=1024,
This is not a part of DB-API.
"""
from pyhive import hive # noqa
from TCLIService import ttypes # noqa
from TCLIService import ttypes # noqa
from thrift import Thrift # pylint: disable=import-error
orientation = orientation or ttypes.TFetchOrientation.FETCH_NEXT
try:
req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
logs = self._connection.client.GetLog(req).log
return logs
# raised if Hive is used
except (ttypes.TApplicationException,
Thrift.TApplicationException):
except (ttypes.TApplicationException, Thrift.TApplicationException):
if self._state == self._STATE_NONE:
raise hive.ProgrammingError('No query yet')
raise hive.ProgrammingError("No query yet")
logs = []
while True:
req = ttypes.TFetchResultsReq(
@ -51,11 +50,10 @@ def fetch_logs(self, max_rows=1024,
)
response = self._connection.client.FetchResults(req)
hive._check_status(response)
assert not response.results.rows, \
'expected data in columnar format'
assert not response.results.rows, "expected data in columnar format"
assert len(response.results.columns) == 1, response.results.columns
new_logs = hive._unwrap_column(response.results.columns[0])
logs += new_logs
if not new_logs:
break
return '\n'.join(logs)
return "\n".join(logs)

View File

@ -36,7 +36,7 @@ def is_subselect(parsed):
if not parsed.is_group():
return False
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() == 'SELECT':
if item.ttype is DML and item.value.upper() == "SELECT":
return True
return False
@ -52,7 +52,7 @@ def extract_from_part(parsed):
raise StopIteration
else:
yield item
elif item.ttype is Keyword and item.value.upper() == 'FROM':
elif item.ttype is Keyword and item.value.upper() == "FROM":
from_seen = True
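A hedged usage sketch for the helpers above: sqlparse does the tokenizing, and the resulting statement is what is_subselect and extract_from_part walk over; the query text is made up:

import sqlparse

sql = "SELECT a FROM (SELECT b FROM t1) x JOIN t2 ON x.b = t2.b"
statement = sqlparse.parse(sql)[0]
# extract_from_part(statement) yields the tokens following each FROM keyword;
# is_subselect(token) reports whether a parenthesized group contains a SELECT.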

View File

@ -20,8 +20,7 @@ from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_appbuilder.forms import DynamicForm
from flask_babel import lazy_gettext as _
from flask_wtf.file import FileAllowed, FileField, FileRequired
from wtforms import (
BooleanField, Field, IntegerField, SelectField, StringField)
from wtforms import BooleanField, Field, IntegerField, SelectField, StringField
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import DataRequired, Length, NumberRange, Optional
@ -36,13 +35,13 @@ class CommaSeparatedListField(Field):
def _value(self):
if self.data:
return u', '.join(self.data)
return u", ".join(self.data)
else:
return u''
return u""
def process_formdata(self, valuelist):
if valuelist:
self.data = [x.strip() for x in valuelist[0].split(',')]
self.data = [x.strip() for x in valuelist[0].split(",")]
else:
self.data = []
@ -61,9 +60,9 @@ class CsvToDatabaseForm(DynamicForm):
# pylint: disable=E0211
def csv_allowed_dbs():
csv_allowed_dbs = []
csv_enabled_dbs = db.session.query(
models.Database).filter_by(
allow_csv_upload=True).all()
csv_enabled_dbs = (
db.session.query(models.Database).filter_by(allow_csv_upload=True).all()
)
for csv_enabled_db in csv_enabled_dbs:
if CsvToDatabaseForm.at_least_one_schema_is_allowed(csv_enabled_db):
csv_allowed_dbs.append(csv_enabled_db)
@ -95,110 +94,132 @@ class CsvToDatabaseForm(DynamicForm):
b) if database supports schema
user is able to upload to schema in schemas_allowed_for_csv_upload
"""
if (security_manager.database_access(database) or
security_manager.all_datasource_access()):
if (
security_manager.database_access(database)
or security_manager.all_datasource_access()
):
return True
schemas = database.get_schema_access_for_csv_upload()
if (schemas and
security_manager.schemas_accessible_by_user(
database, schemas, False)):
if schemas and security_manager.schemas_accessible_by_user(
database, schemas, False
):
return True
return False
name = StringField(
_('Table Name'),
description=_('Name of table to be created from csv data.'),
_("Table Name"),
description=_("Name of table to be created from csv data."),
validators=[DataRequired()],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
csv_file = FileField(
_('CSV File'),
description=_('Select a CSV file to be uploaded to a database.'),
validators=[
FileRequired(), FileAllowed(['csv'], _('CSV Files Only!'))])
_("CSV File"),
description=_("Select a CSV file to be uploaded to a database."),
validators=[FileRequired(), FileAllowed(["csv"], _("CSV Files Only!"))],
)
con = QuerySelectField(
_('Database'),
_("Database"),
query_factory=csv_allowed_dbs,
get_pk=lambda a: a.id, get_label=lambda a: a.database_name)
get_pk=lambda a: a.id,
get_label=lambda a: a.database_name,
)
schema = StringField(
_('Schema'),
description=_('Specify a schema (if database flavor supports this).'),
_("Schema"),
description=_("Specify a schema (if database flavor supports this)."),
validators=[Optional()],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
sep = StringField(
_('Delimiter'),
description=_('Delimiter used by CSV file (for whitespace use \s+).'),
_("Delimiter"),
description=_("Delimiter used by CSV file (for whitespace use \\s+)."),
validators=[DataRequired()],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
if_exists = SelectField(
_('Table Exists'),
_("Table Exists"),
description=_(
'If table exists do one of the following: '
'Fail (do nothing), Replace (drop and recreate table) '
'or Append (insert data).'),
"If table exists do one of the following: "
"Fail (do nothing), Replace (drop and recreate table) "
"or Append (insert data)."
),
choices=[
('fail', _('Fail')), ('replace', _('Replace')),
('append', _('Append'))],
validators=[DataRequired()])
("fail", _("Fail")),
("replace", _("Replace")),
("append", _("Append")),
],
validators=[DataRequired()],
)
header = IntegerField(
_('Header Row'),
_("Header Row"),
description=_(
'Row containing the headers to use as '
'column names (0 is first line of data). '
'Leave empty if there is no header row.'),
"Row containing the headers to use as "
"column names (0 is first line of data). "
"Leave empty if there is no header row."
),
validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
index_col = IntegerField(
_('Index Column'),
_("Index Column"),
description=_(
'Column to use as the row labels of the '
'dataframe. Leave empty if no index column.'),
"Column to use as the row labels of the "
"dataframe. Leave empty if no index column."
),
validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
mangle_dupe_cols = BooleanField(
_('Mangle Duplicate Columns'),
description=_('Specify duplicate columns as "X.0, X.1".'))
_("Mangle Duplicate Columns"),
description=_('Specify duplicate columns as "X.0, X.1".'),
)
skipinitialspace = BooleanField(
_('Skip Initial Space'),
description=_('Skip spaces after delimiter.'))
_("Skip Initial Space"), description=_("Skip spaces after delimiter.")
)
skiprows = IntegerField(
_('Skip Rows'),
description=_('Number of rows to skip at start of file.'),
_("Skip Rows"),
description=_("Number of rows to skip at start of file."),
validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
nrows = IntegerField(
_('Rows to Read'),
description=_('Number of rows of file to read.'),
_("Rows to Read"),
description=_("Number of rows of file to read."),
validators=[Optional(), NumberRange(min=0)],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
skip_blank_lines = BooleanField(
_('Skip Blank Lines'),
_("Skip Blank Lines"),
description=_(
'Skip blank lines rather than interpreting them '
'as NaN values.'))
"Skip blank lines rather than interpreting them " "as NaN values."
),
)
parse_dates = CommaSeparatedListField(
_('Parse Dates'),
_("Parse Dates"),
description=_(
'A comma separated list of columns that should be '
'parsed as dates.'),
filters=[filter_not_empty_values])
"A comma separated list of columns that should be " "parsed as dates."
),
filters=[filter_not_empty_values],
)
infer_datetime_format = BooleanField(
_('Infer Datetime Format'),
description=_(
'Use Pandas to interpret the datetime format '
'automatically.'))
_("Infer Datetime Format"),
description=_("Use Pandas to interpret the datetime format " "automatically."),
)
decimal = StringField(
_('Decimal Character'),
default='.',
description=_('Character to interpret as decimal point.'),
_("Decimal Character"),
default=".",
description=_("Character to interpret as decimal point."),
validators=[Optional(), Length(min=1, max=1)],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
index = BooleanField(
_('Dataframe Index'),
description=_('Write dataframe index as a column.'))
_("Dataframe Index"), description=_("Write dataframe index as a column.")
)
index_label = StringField(
_('Column Label(s)'),
_("Column Label(s)"),
description=_(
'Column label for index column(s). If None is given '
'and Dataframe Index is True, Index Names are used.'),
"Column label for index column(s). If None is given "
"and Dataframe Index is True, Index Names are used."
),
validators=[Optional()],
widget=BS3TextFieldWidget())
widget=BS3TextFieldWidget(),
)
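
As a rough usage sketch, the CommaSeparatedListField defined above round-trips comma separated text into a list; the demo form and its input below are made up:

from werkzeug.datastructures import MultiDict
from wtforms import Form


class DemoForm(Form):
    # Hypothetical form, only to exercise the custom field above.
    parse_dates = CommaSeparatedListField("Parse Dates")


form = DemoForm(formdata=MultiDict({"parse_dates": "ds, created_on"}))
print(form.parse_dates.data)      # ["ds", "created_on"]
print(form.parse_dates._value())  # "ds, created_on"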

View File

@ -31,14 +31,14 @@ from superset import app
config = app.config
BASE_CONTEXT = {
'datetime': datetime,
'random': random,
'relativedelta': relativedelta,
'time': time,
'timedelta': timedelta,
'uuid': uuid,
"datetime": datetime,
"random": random,
"relativedelta": relativedelta,
"time": time,
"timedelta": timedelta,
"uuid": uuid,
}
BASE_CONTEXT.update(config.get('JINJA_CONTEXT_ADDONS', {}))
BASE_CONTEXT.update(config.get("JINJA_CONTEXT_ADDONS", {}))
def url_param(param, default=None):
@ -63,16 +63,16 @@ def url_param(param, default=None):
if request.args.get(param):
return request.args.get(param, default)
# Supporting POST as well as get
if request.form.get('form_data'):
form_data = json.loads(request.form.get('form_data'))
url_params = form_data.get('url_params') or {}
if request.form.get("form_data"):
form_data = json.loads(request.form.get("form_data"))
url_params = form_data.get("url_params") or {}
return url_params.get(param, default)
return default
def current_user_id():
"""The id of the user who is currently logged in"""
if hasattr(g, 'user') and g.user:
if hasattr(g, "user") and g.user:
return g.user.id
@ -88,7 +88,8 @@ def filter_values(column, default=None):
This is useful if:
- you want to use a filter box to filter a query where the name of filter box
column doesn't match the one in the select statement
- you want to have the ability for filter inside the main query for speed purposes
- you want to have the ability for filter inside the main query for speed
purposes
This searches for "filters" and "extra_filters" in form_data for a match
@ -105,19 +106,19 @@ def filter_values(column, default=None):
:return: returns a list of filter values
:type: list
"""
form_data = json.loads(request.form.get('form_data', '{}'))
form_data = json.loads(request.form.get("form_data", "{}"))
return_val = []
for filter_type in ['filters', 'extra_filters']:
for filter_type in ["filters", "extra_filters"]:
if filter_type not in form_data:
continue
for f in form_data[filter_type]:
if f['col'] == column:
if isinstance(f['val'], list):
for v in f['val']:
if f["col"] == column:
if isinstance(f["val"], list):
for v in f["val"]:
return_val.append(v)
else:
return_val.append(f['val'])
return_val.append(f["val"])
if return_val:
return return_val
@ -142,6 +143,7 @@ class BaseTemplateProcessor(object):
and are given access to the ``models.Database`` object and schema
name. For globally available methods use ``@classmethod``.
"""
engine = None
def __init__(self, database=None, query=None, table=None, **kwargs):
@ -153,11 +155,11 @@ class BaseTemplateProcessor(object):
elif table:
self.schema = table.schema
self.context = {
'url_param': url_param,
'current_user_id': current_user_id,
'current_username': current_username,
'filter_values': filter_values,
'form_data': {},
"url_param": url_param,
"current_user_id": current_user_id,
"current_username": current_username,
"filter_values": filter_values,
"form_data": {},
}
self.context.update(kwargs)
self.context.update(BASE_CONTEXT)
@ -183,30 +185,30 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
The methods described here are namespaced under ``presto`` in the
jinja context as in ``SELECT '{{ presto.some_macro_call() }}'``
"""
engine = 'presto'
engine = "presto"
@staticmethod
def _schema_table(table_name, schema):
if '.' in table_name:
schema, table_name = table_name.split('.')
if "." in table_name:
schema, table_name = table_name.split(".")
return table_name, schema
def latest_partition(self, table_name):
table_name, schema = self._schema_table(table_name, self.schema)
return self.database.db_engine_spec.latest_partition(
table_name, schema, self.database)[1]
table_name, schema, self.database
)[1]
def latest_sub_partition(self, table_name, **kwargs):
table_name, schema = self._schema_table(table_name, self.schema)
return self.database.db_engine_spec.latest_sub_partition(
table_name=table_name,
schema=schema,
database=self.database,
**kwargs)
table_name=table_name, schema=schema, database=self.database, **kwargs
)
class HiveTemplateProcessor(PrestoTemplateProcessor):
engine = 'hive'
engine = "hive"
template_processors = {}
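
The docstrings above describe the macros exposed to templated queries; roughly how they might be used from SQL Lab (table, column, and parameter names are made up):

# Hypothetical templated query relying on the Jinja context assembled above.
TEMPLATED_SQL = """
SELECT action, COUNT(*)
FROM logs
WHERE action IN ({{ "'" + "','".join(filter_values('action')) + "'" }})
  AND country = '{{ url_param('country', 'USA') }}'
  AND ds = '{{ presto.latest_partition('default.logs') }}'
GROUP BY action
"""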

View File

@ -20,8 +20,7 @@
def update_time_range(form_data):
"""Move since and until to time_range."""
if 'since' in form_data or 'until' in form_data:
form_data['time_range'] = '{} : {}'.format(
form_data.pop('since', '') or '',
form_data.pop('until', '') or '',
if "since" in form_data or "until" in form_data:
form_data["time_range"] = "{} : {}".format(
form_data.pop("since", "") or "", form_data.pop("until", "") or ""
)
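
A worked example of the shim above (form_data values are made up):

form_data = {"since": "2019-01-01", "until": "2019-07-01", "metric": "count"}
update_time_range(form_data)
print(form_data)  # {'metric': 'count', 'time_range': '2019-01-01 : 2019-07-01'}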

View File

@ -29,16 +29,17 @@ config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
logger = logging.getLogger('alembic.env')
logger = logging.getLogger("alembic.env")
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
from flask import current_app
config.set_main_option('sqlalchemy.url',
current_app.config.get('SQLALCHEMY_DATABASE_URI'))
target_metadata = Base.metadata # pylint: disable=no-member
config.set_main_option(
"sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI")
)
target_metadata = Base.metadata # pylint: disable=no-member
# other values from the config, defined by the needs of env.py,
# can be acquired:
@ -77,32 +78,33 @@ def run_migrations_online():
# when there are no changes to the schema
# reference: https://alembic.sqlalchemy.org/en/latest/cookbook.html
def process_revision_directives(context, revision, directives):
if getattr(config.cmd_opts, 'autogenerate', False):
if getattr(config.cmd_opts, "autogenerate", False):
script = directives[0]
if script.upgrade_ops.is_empty():
directives[:] = []
logger.info('No changes in schema detected.')
logger.info("No changes in schema detected.")
engine = engine_from_config(config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
engine = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connection = engine.connect()
kwargs = {}
if engine.name in ('sqlite', 'mysql'):
kwargs = {
'transaction_per_migration': True,
'transactional_ddl': True,
}
configure_args = current_app.extensions['migrate'].configure_args
if engine.name in ("sqlite", "mysql"):
kwargs = {"transaction_per_migration": True, "transactional_ddl": True}
configure_args = current_app.extensions["migrate"].configure_args
if configure_args:
kwargs.update(configure_args)
context.configure(connection=connection,
target_metadata=target_metadata,
# compare_type=True,
process_revision_directives=process_revision_directives,
**kwargs)
context.configure(
connection=connection,
target_metadata=target_metadata,
# compare_type=True,
process_revision_directives=process_revision_directives,
**kwargs
)
try:
with context.begin_transaction():
@ -110,6 +112,7 @@ def run_migrations_online():
finally:
connection.close()
if context.is_offline_mode():
run_migrations_offline()
else:

View File

@ -25,15 +25,15 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '0b1f1ab473c0'
down_revision = '55e910a74826'
revision = "0b1f1ab473c0"
down_revision = "55e910a74826"
def upgrade():
with op.batch_alter_table('query') as batch_op:
batch_op.add_column(sa.Column('extra_json', sa.Text(), nullable=True))
with op.batch_alter_table("query") as batch_op:
batch_op.add_column(sa.Column("extra_json", sa.Text(), nullable=True))
def downgrade():
with op.batch_alter_table('query') as batch_op:
batch_op.drop_column('extra_json')
with op.batch_alter_table("query") as batch_op:
batch_op.drop_column("extra_json")

View File

@ -23,29 +23,30 @@ Create Date: 2018-08-06 14:38:18.965248
"""
# revision identifiers, used by Alembic.
revision = '0c5070e96b57'
down_revision = '7fcdcde0761c'
revision = "0c5070e96b57"
down_revision = "7fcdcde0761c"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.create_table('user_attribute',
sa.Column('created_on', sa.DateTime(), nullable=True),
sa.Column('changed_on', sa.DateTime(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('welcome_dashboard_id', sa.Integer(), nullable=True),
sa.Column('created_by_fk', sa.Integer(), nullable=True),
sa.Column('changed_by_fk', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['changed_by_fk'], ['ab_user.id'], ),
sa.ForeignKeyConstraint(['created_by_fk'], ['ab_user.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
sa.ForeignKeyConstraint(['welcome_dashboard_id'], ['dashboards.id'], ),
sa.PrimaryKeyConstraint('id')
op.create_table(
"user_attribute",
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("welcome_dashboard_id", sa.Integer(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["welcome_dashboard_id"], ["dashboards.id"]),
sa.PrimaryKeyConstraint("id"),
)
def downgrade():
op.drop_table('user_attribute')
op.drop_table("user_attribute")

View File

@ -27,42 +27,49 @@ from superset.utils.core import generic_find_constraint_name
import logging
# revision identifiers, used by Alembic.
revision = '1226819ee0e3'
down_revision = '956a063c52b3'
revision = "1226819ee0e3"
down_revision = "956a063c52b3"
naming_convention = {
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"
}
def find_constraint_name(upgrade=True):
cols = {'column_name'} if upgrade else {'datasource_name'}
cols = {"column_name"} if upgrade else {"datasource_name"}
return generic_find_constraint_name(
table='columns', columns=cols, referenced='datasources', db=db)
table="columns", columns=cols, referenced="datasources", db=db
)
def upgrade():
try:
constraint = find_constraint_name()
with op.batch_alter_table("columns",
naming_convention=naming_convention) as batch_op:
with op.batch_alter_table(
"columns", naming_convention=naming_convention
) as batch_op:
if constraint:
batch_op.drop_constraint(constraint, type_="foreignkey")
batch_op.create_foreign_key(
'fk_columns_datasource_name_datasources',
'datasources',
['datasource_name'], ['datasource_name'])
"fk_columns_datasource_name_datasources",
"datasources",
["datasource_name"],
["datasource_name"],
)
except:
logging.warning(
"Could not find or drop constraint on `columns`")
logging.warning("Could not find or drop constraint on `columns`")
def downgrade():
constraint = find_constraint_name(False) or 'fk_columns_datasource_name_datasources'
with op.batch_alter_table("columns",
naming_convention=naming_convention) as batch_op:
constraint = find_constraint_name(False) or "fk_columns_datasource_name_datasources"
with op.batch_alter_table(
"columns", naming_convention=naming_convention
) as batch_op:
batch_op.drop_constraint(constraint, type_="foreignkey")
batch_op.create_foreign_key(
'fk_columns_column_name_datasources',
'datasources',
['column_name'], ['datasource_name'])
"fk_columns_column_name_datasources",
"datasources",
["column_name"],
["datasource_name"],
)

View File

@ -23,16 +23,18 @@ Create Date: 2016-12-06 17:40:40.389652
"""
# revision identifiers, used by Alembic.
revision = '1296d28ec131'
down_revision = '6414e83d82b7'
revision = "1296d28ec131"
down_revision = "6414e83d82b7"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('datasources', sa.Column('params', sa.String(length=1000), nullable=True))
op.add_column(
"datasources", sa.Column("params", sa.String(length=1000), nullable=True)
)
def downgrade():
op.drop_column('datasources', 'params')
op.drop_column("datasources", "params")

View File

@ -23,17 +23,16 @@ Create Date: 2015-12-14 13:37:17.374852
"""
# revision identifiers, used by Alembic.
revision = '12d55656cbca'
down_revision = '55179c7f25c7'
revision = "12d55656cbca"
down_revision = "55179c7f25c7"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('tables', sa.Column('is_featured', sa.Boolean(), nullable=True))
op.add_column("tables", sa.Column("is_featured", sa.Boolean(), nullable=True))
def downgrade():
op.drop_column('tables', 'is_featured')
op.drop_column("tables", "is_featured")

View File

@ -28,15 +28,16 @@ from sqlalchemy.ext.declarative import declarative_base
from superset import db
# revision identifiers, used by Alembic.
revision = '130915240929'
down_revision = 'f231d82b9b26'
revision = "130915240929"
down_revision = "f231d82b9b26"
Base = declarative_base()
class Table(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'tables'
__tablename__ = "tables"
id = sa.Column(sa.Integer, primary_key=True)
sql = sa.Column(sa.Text)
is_sqllab_view = sa.Column(sa.Boolean())
@ -45,9 +46,9 @@ class Table(Base):
def upgrade():
bind = op.get_bind()
op.add_column(
'tables',
"tables",
sa.Column(
'is_sqllab_view',
"is_sqllab_view",
sa.Boolean(),
nullable=True,
default=False,
@ -67,4 +68,4 @@ def upgrade():
def downgrade():
op.drop_column('tables', 'is_sqllab_view')
op.drop_column("tables", "is_sqllab_view")

View File

@ -23,8 +23,8 @@ Create Date: 2019-01-18 14:56:26.307684
"""
# revision identifiers, used by Alembic.
revision = '18dc26817ad2'
down_revision = ('8b70aa3d0f87', 'a33a03f16c4a')
revision = "18dc26817ad2"
down_revision = ("8b70aa3d0f87", "a33a03f16c4a")
from alembic import op
import sqlalchemy as sa

View File

@ -25,106 +25,77 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '18e88e1cc004'
down_revision = '430039611635'
revision = "18e88e1cc004"
down_revision = "430039611635"
def upgrade():
try:
op.alter_column(
'clusters', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"clusters", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'clusters', 'created_on',
existing_type=sa.DATETIME(), nullable=True)
op.drop_constraint(None, 'columns', type_='foreignkey')
op.drop_constraint(None, 'columns', type_='foreignkey')
op.drop_column('columns', 'created_on')
op.drop_column('columns', 'created_by_fk')
op.drop_column('columns', 'changed_on')
op.drop_column('columns', 'changed_by_fk')
"clusters", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.drop_constraint(None, "columns", type_="foreignkey")
op.drop_constraint(None, "columns", type_="foreignkey")
op.drop_column("columns", "created_on")
op.drop_column("columns", "created_by_fk")
op.drop_column("columns", "changed_on")
op.drop_column("columns", "changed_by_fk")
op.alter_column(
'css_templates', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"css_templates", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'css_templates', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
"css_templates", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'dashboards', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"dashboards", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'dashboards', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
op.create_unique_constraint(None, 'dashboards', ['slug'])
"dashboards", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.create_unique_constraint(None, "dashboards", ["slug"])
op.alter_column(
'datasources', 'changed_by_fk',
existing_type=sa.INTEGER(),
nullable=True)
"datasources", "changed_by_fk", existing_type=sa.INTEGER(), nullable=True
)
op.alter_column(
'datasources', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"datasources", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'datasources', 'created_by_fk',
existing_type=sa.INTEGER(),
nullable=True)
"datasources", "created_by_fk", existing_type=sa.INTEGER(), nullable=True
)
op.alter_column(
'datasources', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
"datasources", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column("dbs", "changed_on", existing_type=sa.DATETIME(), nullable=True)
op.alter_column("dbs", "created_on", existing_type=sa.DATETIME(), nullable=True)
op.alter_column(
'dbs', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"slices", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'dbs', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
"slices", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'slices', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"sql_metrics", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'slices', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
"sql_metrics", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'sql_metrics', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"table_columns", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'sql_metrics', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
"table_columns", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'table_columns', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
"tables", "changed_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column(
'table_columns', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'tables', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'tables', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'url', 'changed_on',
existing_type=sa.DATETIME(),
nullable=True)
op.alter_column(
'url', 'created_on',
existing_type=sa.DATETIME(),
nullable=True)
"tables", "created_on", existing_type=sa.DATETIME(), nullable=True
)
op.alter_column("url", "changed_on", existing_type=sa.DATETIME(), nullable=True)
op.alter_column("url", "created_on", existing_type=sa.DATETIME(), nullable=True)
except Exception:
pass

View File

@ -23,20 +23,20 @@ Create Date: 2017-09-15 15:09:40.495345
"""
# revision identifiers, used by Alembic.
revision = '19a814813610'
down_revision = 'ca69c70ec99b'
revision = "19a814813610"
down_revision = "ca69c70ec99b"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('metrics', sa.Column('warning_text', sa.Text(), nullable=True))
op.add_column('sql_metrics', sa.Column('warning_text', sa.Text(), nullable=True))
op.add_column("metrics", sa.Column("warning_text", sa.Text(), nullable=True))
op.add_column("sql_metrics", sa.Column("warning_text", sa.Text(), nullable=True))
def downgrade():
with op.batch_alter_table('sql_metrics') as batch_op_sql_metrics:
batch_op_sql_metrics.drop_column('warning_text')
with op.batch_alter_table('metrics') as batch_op_metrics:
batch_op_metrics.drop_column('warning_text')
with op.batch_alter_table("sql_metrics") as batch_op_sql_metrics:
batch_op_sql_metrics.drop_column("warning_text")
with op.batch_alter_table("metrics") as batch_op_metrics:
batch_op_metrics.drop_column("warning_text")

View File

@ -29,14 +29,14 @@ import sqlalchemy as sa
from superset.utils.core import MediumText
# revision identifiers, used by Alembic.
revision = '1a1d627ebd8e'
down_revision = '0c5070e96b57'
revision = "1a1d627ebd8e"
down_revision = "0c5070e96b57"
def upgrade():
with op.batch_alter_table('dashboards') as batch_op:
with op.batch_alter_table("dashboards") as batch_op:
batch_op.alter_column(
'position_json',
"position_json",
existing_type=sa.Text(),
type_=MediumText(),
existing_nullable=True,
@ -44,9 +44,9 @@ def upgrade():
def downgrade():
with op.batch_alter_table('dashboards') as batch_op:
with op.batch_alter_table("dashboards") as batch_op:
batch_op.alter_column(
'position_json',
"position_json",
existing_type=MediumText(),
type_=sa.Text(),
existing_nullable=True,

View File

@ -23,20 +23,21 @@ Create Date: 2015-12-04 09:42:16.973264
"""
# revision identifiers, used by Alembic.
revision = '1a48a5411020'
down_revision = '289ce07647b'
revision = "1a48a5411020"
down_revision = "289ce07647b"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('dashboards', sa.Column('slug', sa.String(length=255), nullable=True))
op.add_column("dashboards", sa.Column("slug", sa.String(length=255), nullable=True))
try:
op.create_unique_constraint('idx_unique_slug', 'dashboards', ['slug'])
op.create_unique_constraint("idx_unique_slug", "dashboards", ["slug"])
except:
pass
def downgrade():
op.drop_constraint(None, 'dashboards', type_='unique')
op.drop_column('dashboards', 'slug')
op.drop_constraint(None, "dashboards", type_="unique")
op.drop_column("dashboards", "slug")

View File

@ -22,16 +22,16 @@ Create Date: 2016-03-25 14:35:44.642576
"""
# revision identifiers, used by Alembic.
revision = '1d2ddd543133'
down_revision = 'd2424a248d63'
revision = "1d2ddd543133"
down_revision = "d2424a248d63"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('logs', sa.Column('dt', sa.Date(), nullable=True))
op.add_column("logs", sa.Column("dt", sa.Date(), nullable=True))
def downgrade():
op.drop_column('logs', 'dt')
op.drop_column("logs", "dt")

View File

@ -26,19 +26,21 @@ import sqlalchemy as sa
from sqlalchemy.sql import expression
# revision identifiers, used by Alembic.
revision = '1d9e835a84f9'
down_revision = '3dda56f1c4c6'
revision = "1d9e835a84f9"
down_revision = "3dda56f1c4c6"
def upgrade():
op.add_column(
'dbs',
"dbs",
sa.Column(
'allow_csv_upload',
"allow_csv_upload",
sa.Boolean(),
nullable=False,
server_default=expression.true()))
server_default=expression.true(),
),
)
def downgrade():
op.drop_column('dbs', 'allow_csv_upload')
op.drop_column("dbs", "allow_csv_upload")

View File

@ -23,15 +23,16 @@ Create Date: 2015-10-05 22:11:00.537054
"""
# revision identifiers, used by Alembic.
revision = '1e2841a4128'
down_revision = '5a7bad26f2a7'
revision = "1e2841a4128"
down_revision = "5a7bad26f2a7"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('table_columns', sa.Column('expression', sa.Text(), nullable=True))
op.add_column("table_columns", sa.Column("expression", sa.Text(), nullable=True))
def downgrade():
op.drop_column('table_columns', 'expression')
op.drop_column("table_columns", "expression")

View File

@ -17,8 +17,7 @@
import json
from alembic import op
from sqlalchemy import (
Column, Integer, or_, String, Text)
from sqlalchemy import Column, Integer, or_, String, Text
from sqlalchemy.ext.declarative import declarative_base
from superset import db
@ -32,14 +31,14 @@ Create Date: 2017-12-17 11:06:30.180267
"""
# revision identifiers, used by Alembic.
revision = '21e88bc06c02'
down_revision = '67a6ac9b727b'
revision = "21e88bc06c02"
down_revision = "67a6ac9b727b"
Base = declarative_base()
class Slice(Base):
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True)
viz_type = Column(String(250))
params = Column(Text)
@ -49,24 +48,27 @@ def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
for slc in session.query(Slice).filter(or_(
Slice.viz_type.like('line'), Slice.viz_type.like('bar'))):
for slc in session.query(Slice).filter(
or_(Slice.viz_type.like("line"), Slice.viz_type.like("bar"))
):
params = json.loads(slc.params)
layers = params.get('annotation_layers', [])
layers = params.get("annotation_layers", [])
if layers:
new_layers = []
for layer in layers:
new_layers.append({
'annotationType': 'INTERVAL',
'style': 'solid',
'name': 'Layer {}'.format(layer),
'show': True,
'overrides': {'since': None, 'until': None},
'value': layer,
'width': 1,
'sourceType': 'NATIVE',
})
params['annotation_layers'] = new_layers
new_layers.append(
{
"annotationType": "INTERVAL",
"style": "solid",
"name": "Layer {}".format(layer),
"show": True,
"overrides": {"since": None, "until": None},
"value": layer,
"width": 1,
"sourceType": "NATIVE",
}
)
params["annotation_layers"] = new_layers
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
@ -77,12 +79,13 @@ def downgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
for slc in session.query(Slice).filter(or_(
Slice.viz_type.like('line'), Slice.viz_type.like('bar'))):
for slc in session.query(Slice).filter(
or_(Slice.viz_type.like("line"), Slice.viz_type.like("bar"))
):
params = json.loads(slc.params)
layers = params.get('annotation_layers', [])
layers = params.get("annotation_layers", [])
if layers:
params['annotation_layers'] = [layer['value'] for layer in layers]
params["annotation_layers"] = [layer["value"] for layer in layers]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
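
For reference, the rewrite performed by this migration turns the old list-of-values format into full layer objects; an illustrative before/after for a made-up slice:

# Illustrative only; slice params are made up.
before = {"viz_type": "line", "annotation_layers": [3]}
after = {
    "viz_type": "line",
    "annotation_layers": [
        {
            "annotationType": "INTERVAL",
            "style": "solid",
            "name": "Layer 3",
            "show": True,
            "overrides": {"since": None, "until": None},
            "value": 3,
            "width": 1,
            "sourceType": "NATIVE",
        }
    ],
}
# downgrade() collapses each layer back to its "value", i.e. [3].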

View File

@ -23,20 +23,20 @@ Create Date: 2015-12-15 17:02:45.128709
"""
# revision identifiers, used by Alembic.
revision = '2591d77e9831'
down_revision = '12d55656cbca'
revision = "2591d77e9831"
down_revision = "12d55656cbca"
from alembic import op
import sqlalchemy as sa
def upgrade():
with op.batch_alter_table('tables') as batch_op:
batch_op.add_column(sa.Column('user_id', sa.Integer()))
batch_op.create_foreign_key('user_id', 'ab_user', ['user_id'], ['id'])
with op.batch_alter_table("tables") as batch_op:
batch_op.add_column(sa.Column("user_id", sa.Integer()))
batch_op.create_foreign_key("user_id", "ab_user", ["user_id"], ["id"])
def downgrade():
with op.batch_alter_table('tables') as batch_op:
batch_op.drop_constraint('user_id', type_='foreignkey')
batch_op.drop_column('user_id')
with op.batch_alter_table("tables") as batch_op:
batch_op.drop_constraint("user_id", type_="foreignkey")
batch_op.drop_column("user_id")

View File

@ -23,8 +23,8 @@ Create Date: 2016-06-27 08:43:52.592242
"""
# revision identifiers, used by Alembic.
revision = '27ae655e4247'
down_revision = 'd8bc074f7aad'
revision = "27ae655e4247"
down_revision = "d8bc074f7aad"
from alembic import op
from superset import db
@ -32,41 +32,51 @@ from sqlalchemy.ext.declarative import declarative_base
from flask_appbuilder.models.mixins import AuditMixin
from sqlalchemy.orm import relationship
from flask_appbuilder import Model
from sqlalchemy import (
Column, Integer, ForeignKey, Table)
from sqlalchemy import Column, Integer, ForeignKey, Table
Base = declarative_base()
class User(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'ab_user'
__tablename__ = "ab_user"
id = Column(Integer, primary_key=True)
slice_user = Table('slice_user', Base.metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('ab_user.id')),
Column('slice_id', Integer, ForeignKey('slices.id'))
slice_user = Table(
"slice_user",
Base.metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("slice_id", Integer, ForeignKey("slices.id")),
)
dashboard_user = Table(
'dashboard_user', Base.metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('ab_user.id')),
Column('dashboard_id', Integer, ForeignKey('dashboards.id'))
"dashboard_user",
Base.metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("dashboard_id", Integer, ForeignKey("dashboards.id")),
)
class Slice(Base, AuditMixin):
"""Declarative class to do query in upgrade"""
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True)
owners = relationship("User", secondary=slice_user)
class Dashboard(Base, AuditMixin):
"""Declarative class to do query in upgrade"""
__tablename__ = 'dashboards'
__tablename__ = "dashboards"
id = Column(Integer, primary_key=True)
owners = relationship("User", secondary=dashboard_user)
def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)

View File

@ -24,21 +24,18 @@ Create Date: 2015-11-21 11:18:00.650587
from alembic import op
import sqlalchemy as sa
from sqlalchemy_utils import EncryptedType
from sqlalchemy_utils import EncryptedType
# revision identifiers, used by Alembic.
revision = '289ce07647b'
down_revision = '2929af7925ed'
revision = "289ce07647b"
down_revision = "2929af7925ed"
def upgrade():
op.add_column(
'dbs',
sa.Column(
'password',
EncryptedType(sa.String(1024)),
nullable=True))
"dbs", sa.Column("password", EncryptedType(sa.String(1024)), nullable=True)
)
def downgrade():
op.drop_column('dbs', 'password')
op.drop_column("dbs", "password")

View File

@ -23,17 +23,18 @@ Create Date: 2015-10-19 20:54:00.565633
"""
# revision identifiers, used by Alembic.
revision = '2929af7925ed'
down_revision = '1e2841a4128'
revision = "2929af7925ed"
down_revision = "1e2841a4128"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('datasources', sa.Column('offset', sa.Integer(), nullable=True))
op.add_column('tables', sa.Column('offset', sa.Integer(), nullable=True))
op.add_column("datasources", sa.Column("offset", sa.Integer(), nullable=True))
op.add_column("tables", sa.Column("offset", sa.Integer(), nullable=True))
def downgrade():
op.drop_column('tables', 'offset')
op.drop_column('datasources', 'offset')
op.drop_column("tables", "offset")
op.drop_column("datasources", "offset")

View File

@ -25,31 +25,31 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2fcdcb35e487'
down_revision = 'a6c18f869a4e'
revision = "2fcdcb35e487"
down_revision = "a6c18f869a4e"
def upgrade():
op.create_table(
'saved_query',
sa.Column('created_on', sa.DateTime(), nullable=True),
sa.Column('changed_on', sa.DateTime(), nullable=True),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('db_id', sa.Integer(), nullable=True),
sa.Column('label', sa.String(256), nullable=True),
sa.Column('schema', sa.String(128), nullable=True),
sa.Column('sql', sa.Text(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('changed_by_fk', sa.Integer(), nullable=True),
sa.Column('created_by_fk', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['changed_by_fk'], ['ab_user.id'], ),
sa.ForeignKeyConstraint(['created_by_fk'], ['ab_user.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
sa.ForeignKeyConstraint(['db_id'], ['dbs.id'], ),
sa.PrimaryKeyConstraint('id')
"saved_query",
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("db_id", sa.Integer(), nullable=True),
sa.Column("label", sa.String(256), nullable=True),
sa.Column("schema", sa.String(128), nullable=True),
sa.Column("sql", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.ForeignKeyConstraint(["db_id"], ["dbs.id"]),
sa.PrimaryKeyConstraint("id"),
)
def downgrade():
op.drop_table('saved_query')
op.drop_table("saved_query")

View File

@ -23,8 +23,8 @@ Create Date: 2018-04-08 07:34:12.149910
"""
# revision identifiers, used by Alembic.
revision = '30bb17c0dc76'
down_revision = 'f231d82b9b26'
revision = "30bb17c0dc76"
down_revision = "f231d82b9b26"
from datetime import date
@ -33,10 +33,10 @@ import sqlalchemy as sa
def upgrade():
with op.batch_alter_table('logs') as batch_op:
batch_op.drop_column('dt')
with op.batch_alter_table("logs") as batch_op:
batch_op.drop_column("dt")
def downgrade():
with op.batch_alter_table('logs') as batch_op:
batch_op.add_column(sa.Column('dt', sa.Date, default=date.today()))
with op.batch_alter_table("logs") as batch_op:
batch_op.add_column(sa.Column("dt", sa.Date, default=date.today()))

View File

@ -23,24 +23,25 @@ Create Date: 2015-12-04 11:16:58.226984
"""
# revision identifiers, used by Alembic.
revision = '315b3f4da9b0'
down_revision = '1a48a5411020'
revision = "315b3f4da9b0"
down_revision = "1a48a5411020"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.create_table('logs',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('action', sa.String(length=512), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('json', sa.Text(), nullable=True),
sa.Column('dttm', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
sa.PrimaryKeyConstraint('id')
op.create_table(
"logs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("action", sa.String(length=512), nullable=True),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("json", sa.Text(), nullable=True),
sa.Column("dttm", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.PrimaryKeyConstraint("id"),
)
def downgrade():
op.drop_table('logs')
op.drop_table("logs")

View File

@ -18,8 +18,7 @@ from alembic import op
import sqlalchemy as sa
from superset import db
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import (
Column, Integer, String)
from sqlalchemy import Column, Integer, String
"""update slice model
@ -30,15 +29,16 @@ Create Date: 2016-09-07 23:50:59.366779
"""
# revision identifiers, used by Alembic.
revision = '33d996bcc382'
down_revision = '41f6a59a61f2'
revision = "33d996bcc382"
down_revision = "41f6a59a61f2"
Base = declarative_base()
class Slice(Base):
"""Declarative class to do query in upgrade"""
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True)
datasource_id = Column(Integer)
druid_datasource_id = Column(Integer)
@ -48,7 +48,7 @@ class Slice(Base):
def upgrade():
bind = op.get_bind()
op.add_column('slices', sa.Column('datasource_id', sa.Integer()))
op.add_column("slices", sa.Column("datasource_id", sa.Integer()))
session = db.Session(bind=bind)
for slc in session.query(Slice).all():
@ -65,11 +65,11 @@ def downgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
for slc in session.query(Slice).all():
if slc.datasource_type == 'druid':
if slc.datasource_type == "druid":
slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == 'table':
if slc.datasource_type == "table":
slc.table_id = slc.datasource_id
session.merge(slc)
session.commit()
session.close()
op.drop_column('slices', 'datasource_id')
op.drop_column("slices", "datasource_id")

View File

@ -32,86 +32,101 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = '3b626e2a6783'
down_revision = 'eca4694defa7'
revision = "3b626e2a6783"
down_revision = "eca4694defa7"
def upgrade():
# cleanup after: https://github.com/airbnb/superset/pull/1078
try:
slices_ibfk_1 = generic_find_constraint_name(
table='slices', columns={'druid_datasource_id'},
referenced='datasources', db=db)
table="slices",
columns={"druid_datasource_id"},
referenced="datasources",
db=db,
)
slices_ibfk_2 = generic_find_constraint_name(
table='slices', columns={'table_id'},
referenced='tables', db=db)
table="slices", columns={"table_id"}, referenced="tables", db=db
)
with op.batch_alter_table('slices') as batch_op:
with op.batch_alter_table("slices") as batch_op:
if slices_ibfk_1:
batch_op.drop_constraint(slices_ibfk_1, type_='foreignkey')
batch_op.drop_constraint(slices_ibfk_1, type_="foreignkey")
if slices_ibfk_2:
batch_op.drop_constraint(slices_ibfk_2, type_='foreignkey')
batch_op.drop_column('druid_datasource_id')
batch_op.drop_column('table_id')
batch_op.drop_constraint(slices_ibfk_2, type_="foreignkey")
batch_op.drop_column("druid_datasource_id")
batch_op.drop_column("table_id")
except Exception as e:
logging.warning(str(e))
# fixed issue: https://github.com/airbnb/superset/issues/466
try:
with op.batch_alter_table('columns') as batch_op:
with op.batch_alter_table("columns") as batch_op:
batch_op.create_foreign_key(
None, 'datasources', ['datasource_name'], ['datasource_name'])
None, "datasources", ["datasource_name"], ["datasource_name"]
)
except Exception as e:
logging.warning(str(e))
try:
with op.batch_alter_table('query') as batch_op:
batch_op.create_unique_constraint('client_id', ['client_id'])
with op.batch_alter_table("query") as batch_op:
batch_op.create_unique_constraint("client_id", ["client_id"])
except Exception as e:
logging.warning(str(e))
try:
with op.batch_alter_table('query') as batch_op:
batch_op.drop_column('name')
with op.batch_alter_table("query") as batch_op:
batch_op.drop_column("name")
except Exception as e:
logging.warning(str(e))
def downgrade():
try:
with op.batch_alter_table('tables') as batch_op:
batch_op.create_index('table_name', ['table_name'], unique=True)
with op.batch_alter_table("tables") as batch_op:
batch_op.create_index("table_name", ["table_name"], unique=True)
except Exception as e:
logging.warning(str(e))
try:
with op.batch_alter_table('slices') as batch_op:
batch_op.add_column(sa.Column(
'table_id', mysql.INTEGER(display_width=11),
autoincrement=False, nullable=True))
batch_op.add_column(sa.Column(
'druid_datasource_id', sa.Integer(), autoincrement=False,
nullable=True))
with op.batch_alter_table("slices") as batch_op:
batch_op.add_column(
sa.Column(
"table_id",
mysql.INTEGER(display_width=11),
autoincrement=False,
nullable=True,
)
)
batch_op.add_column(
sa.Column(
"druid_datasource_id",
sa.Integer(),
autoincrement=False,
nullable=True,
)
)
batch_op.create_foreign_key(
'slices_ibfk_1', 'datasources', ['druid_datasource_id'],
['id'])
batch_op.create_foreign_key(
'slices_ibfk_2', 'tables', ['table_id'], ['id'])
"slices_ibfk_1", "datasources", ["druid_datasource_id"], ["id"]
)
batch_op.create_foreign_key("slices_ibfk_2", "tables", ["table_id"], ["id"])
except Exception as e:
logging.warning(str(e))
try:
fk_columns = generic_find_constraint_name(
table='columns', columns={'datasource_name'},
referenced='datasources', db=db)
with op.batch_alter_table('columns') as batch_op:
batch_op.drop_constraint(fk_columns, type_='foreignkey')
table="columns",
columns={"datasource_name"},
referenced="datasources",
db=db,
)
with op.batch_alter_table("columns") as batch_op:
batch_op.drop_constraint(fk_columns, type_="foreignkey")
except Exception as e:
logging.warning(str(e))
op.add_column(
'query', sa.Column('name', sa.String(length=256), nullable=True))
op.add_column("query", sa.Column("name", sa.String(length=256), nullable=True))
try:
with op.batch_alter_table('query') as batch_op:
batch_op.drop_constraint('client_id', type_='unique')
with op.batch_alter_table("query") as batch_op:
batch_op.drop_constraint("client_id", type_="unique")
except Exception as e:
logging.warning(str(e))

View File

@ -23,16 +23,16 @@ Create Date: 2016-08-18 14:06:28.784699
"""
# revision identifiers, used by Alembic.
revision = '3c3ffe173e4f'
down_revision = 'ad82a75afd82'
revision = "3c3ffe173e4f"
down_revision = "ad82a75afd82"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('tables', sa.Column('sql', sa.Text(), nullable=True))
op.add_column("tables", sa.Column("sql", sa.Text(), nullable=True))
def downgrade():
op.drop_column('tables', 'sql')
op.drop_column("tables", "sql")

View File

@ -35,45 +35,41 @@ from sqlalchemy import Column, Integer, String, Text
from superset import db
from superset.utils.core import parse_human_timedelta
revision = '3dda56f1c4c6'
down_revision = 'bddc498dd179'
revision = "3dda56f1c4c6"
down_revision = "bddc498dd179"
Base = declarative_base()
class Slice(Base):
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True)
datasource_type = Column(String(200))
params = Column(Text)
comparison_type_map = {
'factor': 'ratio',
'growth': 'percentage',
'value': 'absolute',
}
comparison_type_map = {"factor": "ratio", "growth": "percentage", "value": "absolute"}
db_engine_specs_map = {
'second': 'PT1S',
'minute': 'PT1M',
'5 minute': 'PT5M',
'10 minute': 'PT10M',
'half hour': 'PT0.5H',
'hour': 'PT1H',
'day': 'P1D',
'week': 'P1W',
'week_ending_saturday': 'P1W',
'week_start_sunday': 'P1W',
'week_start_monday': 'P1W',
'week_starting_sunday': 'P1W',
'P1W/1970-01-03T00:00:00Z': 'P1W',
'1969-12-28T00:00:00Z/P1W': 'P1W',
'month': 'P1M',
'quarter': 'P0.25Y',
'year': 'P1Y',
"second": "PT1S",
"minute": "PT1M",
"5 minute": "PT5M",
"10 minute": "PT10M",
"half hour": "PT0.5H",
"hour": "PT1H",
"day": "P1D",
"week": "P1W",
"week_ending_saturday": "P1W",
"week_start_sunday": "P1W",
"week_start_monday": "P1W",
"week_starting_sunday": "P1W",
"P1W/1970-01-03T00:00:00Z": "P1W",
"1969-12-28T00:00:00Z/P1W": "P1W",
"month": "P1M",
"quarter": "P0.25Y",
"year": "P1Y",
}
@ -81,41 +77,36 @@ def isodate_duration_to_string(obj):
if obj.tdelta:
if not obj.months and not obj.years:
return format_seconds(obj.tdelta.total_seconds())
raise Exception('Unable to convert: {0}'.format(obj))
raise Exception("Unable to convert: {0}".format(obj))
if obj.months % 12 != 0:
months = obj.months + 12 * obj.years
return '{0} months'.format(months)
return "{0} months".format(months)
return '{0} years'.format(obj.years + obj.months // 12)
return "{0} years".format(obj.years + obj.months // 12)
def timedelta_to_string(obj):
if obj.microseconds:
raise Exception('Unable to convert: {0}'.format(obj))
raise Exception("Unable to convert: {0}".format(obj))
elif obj.seconds:
return format_seconds(obj.total_seconds())
elif obj.days % 7 == 0:
return '{0} weeks'.format(obj.days // 7)
return "{0} weeks".format(obj.days // 7)
else:
return '{0} days'.format(obj.days)
return "{0} days".format(obj.days)
def format_seconds(value):
periods = [
('minute', 60),
('hour', 3600),
('day', 86400),
('week', 604800),
]
periods = [("minute", 60), ("hour", 3600), ("day", 86400), ("week", 604800)]
for period, multiple in periods:
if value % multiple == 0:
value //= multiple
break
else:
period = 'second'
period = "second"
return '{0} {1}{2}'.format(value, period, 's' if value > 1 else '')
return "{0} {1}{2}".format(value, period, "s" if value > 1 else "")
def compute_time_compare(granularity, periods):
@ -129,11 +120,11 @@ def compute_time_compare(granularity, periods):
obj = isodate.parse_duration(granularity) * periods
except isodate.isoerror.ISO8601Error:
# if parse_human_timedelta can parse it, return it directly
delta = '{0} {1}{2}'.format(periods, granularity, 's' if periods > 1 else '')
delta = "{0} {1}{2}".format(periods, granularity, "s" if periods > 1 else "")
obj = parse_human_timedelta(delta)
if obj:
return delta
raise Exception('Unable to parse: {0}'.format(granularity))
raise Exception("Unable to parse: {0}".format(granularity))
if isinstance(obj, isodate.duration.Duration):
return isodate_duration_to_string(obj)
@ -146,21 +137,24 @@ def upgrade():
session = db.Session(bind=bind)
for chart in session.query(Slice):
params = json.loads(chart.params or '{}')
params = json.loads(chart.params or "{}")
if not params.get('num_period_compare'):
if not params.get("num_period_compare"):
continue
num_period_compare = int(params.get('num_period_compare'))
granularity = (params.get('granularity') if chart.datasource_type == 'druid'
else params.get('time_grain_sqla'))
num_period_compare = int(params.get("num_period_compare"))
granularity = (
params.get("granularity")
if chart.datasource_type == "druid"
else params.get("time_grain_sqla")
)
time_compare = compute_time_compare(granularity, num_period_compare)
period_ratio_type = params.get('period_ratio_type') or 'growth'
period_ratio_type = params.get("period_ratio_type") or "growth"
comparison_type = comparison_type_map[period_ratio_type.lower()]
params['time_compare'] = [time_compare]
params['comparison_type'] = comparison_type
params["time_compare"] = [time_compare]
params["comparison_type"] = comparison_type
chart.params = json.dumps(params, sort_keys=True)
session.commit()
@ -172,11 +166,11 @@ def downgrade():
session = db.Session(bind=bind)
for chart in session.query(Slice):
params = json.loads(chart.params or '{}')
params = json.loads(chart.params or "{}")
if 'time_compare' in params or 'comparison_type' in params:
params.pop('time_compare', None)
params.pop('comparison_type', None)
if "time_compare" in params or "comparison_type" in params:
params.pop("time_compare", None)
params.pop("comparison_type", None)
chart.params = json.dumps(params, sort_keys=True)
session.commit()
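
The inline comment above explains the fallback to parse_human_timedelta; a few worked conversions, assuming the plain-timedelta branch that follows the Duration check (granularities chosen for illustration):

compute_time_compare("P1D", 14)  # -> "2 weeks"   (14 days collapse into whole weeks)
compute_time_compare("P1M", 3)   # -> "3 months"  (isodate.Duration branch)
compute_time_compare("day", 1)   # -> "1 day"     (parse_human_timedelta fallback)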

View File

@ -26,57 +26,63 @@ Create Date: 2018-12-15 12:34:47.228756
from superset import db
from superset.utils.core import generic_find_fk_constraint_name
revision = '3e1b21cd94a4'
down_revision = '6c7537a6004a'
revision = "3e1b21cd94a4"
down_revision = "6c7537a6004a"
from alembic import op
import sqlalchemy as sa
sqlatable_user = sa.Table(
'sqlatable_user', sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
sa.Column('table_id', sa.Integer, sa.ForeignKey('tables.id')),
"sqlatable_user",
sa.MetaData(),
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
sa.Column("table_id", sa.Integer, sa.ForeignKey("tables.id")),
)
SqlaTable = sa.Table(
'tables', sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
"tables",
sa.MetaData(),
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
)
druiddatasource_user = sa.Table(
'druiddatasource_user', sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
sa.Column('datasource_id', sa.Integer, sa.ForeignKey('datasources.id')),
"druiddatasource_user",
sa.MetaData(),
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
sa.Column("datasource_id", sa.Integer, sa.ForeignKey("datasources.id")),
)
DruidDatasource = sa.Table(
'datasources', sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('user_id', sa.Integer, sa.ForeignKey('ab_user.id')),
"datasources",
sa.MetaData(),
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("user_id", sa.Integer, sa.ForeignKey("ab_user.id")),
)
def upgrade():
op.create_table('sqlatable_user',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('table_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['table_id'], ['tables.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('druiddatasource_user',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('datasource_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['datasource_id'], ['datasources.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['ab_user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table(
"sqlatable_user",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("table_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["table_id"], ["tables.id"]),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"druiddatasource_user",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("datasource_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["datasource_id"], ["datasources.id"]),
sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"]),
sa.PrimaryKeyConstraint("id"),
)
bind = op.get_bind()
insp = sa.engine.reflection.Inspector.from_engine(bind)
@ -93,29 +99,31 @@ def upgrade():
for druiddatasource in druiddatasources:
if druiddatasource.user_id is not None:
session.execute(
druiddatasource_user.insert().values(user_id=druiddatasource.user_id, datasource_id=druiddatasource.id)
druiddatasource_user.insert().values(
user_id=druiddatasource.user_id, datasource_id=druiddatasource.id
)
)
session.close()
with op.batch_alter_table('tables') as batch_op:
batch_op.drop_constraint('user_id', type_='foreignkey')
batch_op.drop_column('user_id')
with op.batch_alter_table('datasources') as batch_op:
batch_op.drop_constraint(generic_find_fk_constraint_name(
'datasources',
{'id'},
'ab_user',
insp,
), type_='foreignkey')
batch_op.drop_column('user_id')
with op.batch_alter_table("tables") as batch_op:
batch_op.drop_constraint("user_id", type_="foreignkey")
batch_op.drop_column("user_id")
with op.batch_alter_table("datasources") as batch_op:
batch_op.drop_constraint(
generic_find_fk_constraint_name("datasources", {"id"}, "ab_user", insp),
type_="foreignkey",
)
batch_op.drop_column("user_id")
def downgrade():
op.drop_table('sqlatable_user')
op.drop_table('druiddatasource_user')
with op.batch_alter_table('tables') as batch_op:
batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True))
batch_op.create_foreign_key('user_id', 'ab_user', ['user_id'], ['id'])
with op.batch_alter_table('datasources') as batch_op:
batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True))
batch_op.create_foreign_key('fk_datasources_user_id_ab_user', 'ab_user', ['user_id'], ['id'])
op.drop_table("sqlatable_user")
op.drop_table("druiddatasource_user")
with op.batch_alter_table("tables") as batch_op:
batch_op.add_column(sa.Column("user_id", sa.INTEGER(), nullable=True))
batch_op.create_foreign_key("user_id", "ab_user", ["user_id"], ["id"])
with op.batch_alter_table("datasources") as batch_op:
batch_op.add_column(sa.Column("user_id", sa.INTEGER(), nullable=True))
batch_op.create_foreign_key(
"fk_datasources_user_id_ab_user", "ab_user", ["user_id"], ["id"]
)

View File

@ -25,20 +25,19 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '41f6a59a61f2'
down_revision = '3c3ffe173e4f'
revision = "41f6a59a61f2"
down_revision = "3c3ffe173e4f"
def upgrade():
op.add_column('dbs', sa.Column('allow_ctas', sa.Boolean(), nullable=True))
op.add_column("dbs", sa.Column("allow_ctas", sa.Boolean(), nullable=True))
op.add_column("dbs", sa.Column("expose_in_sqllab", sa.Boolean(), nullable=True))
op.add_column(
'dbs', sa.Column('expose_in_sqllab', sa.Boolean(), nullable=True))
op.add_column(
'dbs',
sa.Column('force_ctas_schema', sa.String(length=250), nullable=True))
"dbs", sa.Column("force_ctas_schema", sa.String(length=250), nullable=True)
)
def downgrade():
op.drop_column('dbs', 'force_ctas_schema')
op.drop_column('dbs', 'expose_in_sqllab')
op.drop_column('dbs', 'allow_ctas')
op.drop_column("dbs", "force_ctas_schema")
op.drop_column("dbs", "expose_in_sqllab")
op.drop_column("dbs", "allow_ctas")

@ -25,15 +25,15 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '430039611635'
down_revision = 'd827694c7555'
revision = "430039611635"
down_revision = "d827694c7555"
def upgrade():
op.add_column('logs', sa.Column('dashboard_id', sa.Integer(), nullable=True))
op.add_column('logs', sa.Column('slice_id', sa.Integer(), nullable=True))
op.add_column("logs", sa.Column("dashboard_id", sa.Integer(), nullable=True))
op.add_column("logs", sa.Column("slice_id", sa.Integer(), nullable=True))
def downgrade():
op.drop_column('logs', 'slice_id')
op.drop_column('logs', 'dashboard_id')
op.drop_column("logs", "slice_id")
op.drop_column("logs", "dashboard_id")

@ -23,16 +23,16 @@ Create Date: 2016-01-18 23:43:16.073483
"""
# revision identifiers, used by Alembic.
revision = '43df8de3a5f4'
down_revision = '7dbf98566af7'
revision = "43df8de3a5f4"
down_revision = "7dbf98566af7"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('dashboards', sa.Column('json_metadata', sa.Text(), nullable=True))
op.add_column("dashboards", sa.Column("json_metadata", sa.Text(), nullable=True))
def downgrade():
op.drop_column('dashboards', 'json_metadata')
op.drop_column("dashboards", "json_metadata")

@ -23,8 +23,8 @@ Create Date: 2018-06-13 10:20:35.846744
"""
# revision identifiers, used by Alembic.
revision = '4451805bbaa1'
down_revision = 'bddc498dd179'
revision = "4451805bbaa1"
down_revision = "bddc498dd179"
from alembic import op
@ -38,23 +38,23 @@ Base = declarative_base()
class Slice(Base):
__tablename__ = 'slices'
__tablename__ = "slices"
id = Column(Integer, primary_key=True)
datasource_id = Column(Integer, ForeignKey('tables.id'))
datasource_id = Column(Integer, ForeignKey("tables.id"))
datasource_type = Column(String(200))
params = Column(Text)
class Table(Base):
__tablename__ = 'tables'
__tablename__ = "tables"
id = Column(Integer, primary_key=True)
database_id = Column(Integer, ForeignKey('dbs.id'))
database_id = Column(Integer, ForeignKey("dbs.id"))
class Database(Base):
__tablename__ = 'dbs'
__tablename__ = "dbs"
id = Column(Integer, primary_key=True)
sqlalchemy_uri = Column(String(1024))
@ -68,7 +68,7 @@ def replace(source, target):
session.query(Slice, Database)
.join(Table, Slice.datasource_id == Table.id)
.join(Database, Table.database_id == Database.id)
.filter(Slice.datasource_type == 'table')
.filter(Slice.datasource_type == "table")
.all()
)
@ -79,11 +79,11 @@ def replace(source, target):
if engine.dialect.identifier_preparer._double_percents:
params = json.loads(slc.params)
if 'adhoc_filters' in params:
for filt in params['adhoc_filters']:
if 'sqlExpression' in filt:
filt['sqlExpression'] = (
filt['sqlExpression'].replace(source, target)
if "adhoc_filters" in params:
for filt in params["adhoc_filters"]:
if "sqlExpression" in filt:
filt["sqlExpression"] = filt["sqlExpression"].replace(
source, target
)
slc.params = json.dumps(params, sort_keys=True)
@ -95,8 +95,8 @@ def replace(source, target):
def upgrade():
replace('%%', '%')
replace("%%", "%")
def downgrade():
replace('%', '%%')
replace("%", "%%")

@ -23,22 +23,21 @@ Create Date: 2016-09-12 23:33:14.789632
"""
# revision identifiers, used by Alembic.
revision = '4500485bde7d'
down_revision = '41f6a59a61f2'
revision = "4500485bde7d"
down_revision = "41f6a59a61f2"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('dbs', sa.Column('allow_run_async', sa.Boolean(), nullable=True))
op.add_column('dbs', sa.Column('allow_run_sync', sa.Boolean(), nullable=True))
op.add_column("dbs", sa.Column("allow_run_async", sa.Boolean(), nullable=True))
op.add_column("dbs", sa.Column("allow_run_sync", sa.Boolean(), nullable=True))
def downgrade():
try:
op.drop_column('dbs', 'allow_run_sync')
op.drop_column('dbs', 'allow_run_async')
op.drop_column("dbs", "allow_run_sync")
op.drop_column("dbs", "allow_run_async")
except Exception:
pass

@ -23,8 +23,8 @@ Create Date: 2019-02-16 17:44:44.493427
"""
# revision identifiers, used by Alembic.
revision = '45e7da7cfeba'
down_revision = ('e553e78e90c5', 'c82ee8a39623')
revision = "45e7da7cfeba"
down_revision = ("e553e78e90c5", "c82ee8a39623")
from alembic import op
import sqlalchemy as sa

Some files were not shown because too many files have changed in this diff.
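
As an editorial illustration (not part of the commit), the short sketch below runs Black programmatically on a snippet resembling the pre-change migration code shown above. It assumes the black package is installed; black.format_str and black.FileMode are its public formatting entry points, and the exact output may vary slightly between Black versions.

import black

# Hypothetical input resembling the pre-Black style seen in the removed lines above;
# the names op/sa are only placeholders, since Black needs valid syntax, not valid imports.
SOURCE = (
    "op.add_column('dbs',\n"
    "    sa.Column('force_ctas_schema', sa.String(length=250), nullable=True))\n"
)

# FileMode(line_length=88) mirrors Black's defaults: 88-character lines and string
# normalization, which is what rewrites single-quoted strings to double quotes.
formatted = black.format_str(SOURCE, mode=black.FileMode(line_length=88))
print(formatted)

Running this reproduces the kind of transformation seen throughout the diff: the call is split across lines to fit the 88-character limit and the string quotes are normalized to double quotes.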