fix: memoize primitives (#19930)

This commit is contained in:
Beto Dealmeida 2022-05-02 14:50:56 -07:00 committed by GitHub
parent 7b3d0f040b
commit 1ebdaac487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 49 deletions

View File

@ -864,18 +864,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
all_datasources: List[utils.DatasourceName] = []
for schema in schemas:
if datasource_type == "table":
all_datasources += database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
all_datasources.extend(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
)
elif datasource_type == "view":
all_datasources += database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
all_datasources.extend(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
)
else:
raise Exception(f"Unsupported datasource_type: {datasource_type}")

View File

@ -81,19 +81,25 @@ class SqliteEngineSpec(BaseEngineSpec):
)
schema = schemas[0]
if datasource_type == "table":
return database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
return [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
]
if datasource_type == "view":
return database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
return [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
]
raise Exception(f"Unsupported datasource_type: {datasource_type}")
@classmethod

View File

@ -522,11 +522,16 @@ class Database(
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
return self.db_engine_spec.get_all_datasource_names(self, "table")
return [
(datasource_name.table, datasource_name.schema)
for datasource_name in self.db_engine_spec.get_all_datasource_names(
self, "table"
)
]
@cache_util.memoized_func(
key="db:{self.id}:schema:None:view_list",
@ -537,11 +542,16 @@ class Database(
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
return self.db_engine_spec.get_all_datasource_names(self, "view")
return [
(datasource_name.table, datasource_name.schema)
for datasource_name in self.db_engine_spec.get_all_datasource_names(
self, "view"
)
]
@cache_util.memoized_func(
key="db:{self.id}:schema:{schema}:table_list",
@ -553,7 +563,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@ -569,9 +579,7 @@ class Database(
tables = self.db_engine_spec.get_table_names(
database=self, inspector=self.inspector, schema=schema
)
return [
utils.DatasourceName(table=table, schema=schema) for table in tables
]
return [(table, schema) for table in tables]
except Exception as ex: # pylint: disable=broad-except
logger.warning(ex)
return []
@ -586,7 +594,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@ -602,7 +610,7 @@ class Database(
views = self.db_engine_spec.get_view_names(
database=self, inspector=self.inspector, schema=schema
)
return [utils.DatasourceName(table=view, schema=schema) for view in views]
return [(view, schema) for view in views]
except Exception as ex: # pylint: disable=broad-except
logger.warning(ex)
return []

View File

@ -98,7 +98,18 @@ def memoized_func(
key: Optional[str] = None,
cache: Cache = cache_manager.cache,
) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg.
"""
Decorator with configurable key and cache backend.
@memoized_func(key="{a}+{b}", cache=cache_manager.data_cache)
def sum(a: int, b: int) -> int:
return a + b
In the example above the result for `1+2` will be stored under the key of name "1+2",
in the `cache_manager.data_cache` cache.
Note: this decorator should be used only with functions that return primitives,
otherwise the deserialization might not work correctly.
enable_cache is treated as True by default,
except enable_cache = False is passed to the decorated function.

View File

@ -1115,31 +1115,37 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
substr_parsed = utils.parse_js_uri_path_item(substr, eval_undefined=True)
if schema_parsed:
tables = (
database.get_all_table_names_in_schema(
tables = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
or []
)
views = (
database.get_all_view_names_in_schema(
] or []
views = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
or []
)
] or []
else:
tables = database.get_all_table_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
views = database.get_all_view_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
tables = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
]
views = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
]
tables = security_manager.get_datasources_accessible_by_user(
database, tables, schema_parsed
)

View File

@ -46,7 +46,7 @@ def test_get_all_datasource_names_table(app_context: AppContext) -> None:
database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
table_names = ["table1", "table2"]
table_names = [("table1", "schema1"), ("table2", "schema1")]
get_tables = mock.MagicMock(return_value=table_names)
database.get_all_table_names_in_schema = get_tables
result = SqliteEngineSpec.get_all_datasource_names(database, "table")
@ -65,7 +65,7 @@ def test_get_all_datasource_names_view(app_context: AppContext) -> None:
database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
views_names = ["view1", "view2"]
views_names = [("view1", "schema1"), ("view2", "schema1")]
get_views = mock.MagicMock(return_value=views_names)
database.get_all_view_names_in_schema = get_views
result = SqliteEngineSpec.get_all_datasource_names(database, "view")