fix(ssh-tunnel): fix dataset creation flow through modal for DB with tunnel (#22581)

This commit is contained in:
Hugh A. Miles II 2023-01-06 13:52:05 -05:00 committed by GitHub
parent af34e454be
commit d18c7d6128
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 11 deletions

View File

@ -744,13 +744,14 @@ class Database(
def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
return Table(
table_name,
meta,
schema=schema or None,
autoload=True,
autoload_with=self._get_sqla_engine(),
)
with self.get_sqla_engine_with_context() as engine:
return Table(
table_name,
meta,
schema=schema or None,
autoload=True,
autoload_with=engine,
)
def get_table_comment(
self, table_name: str, schema: Optional[str] = None
@ -846,12 +847,12 @@ class Database(
return self.perm # type: ignore
def has_table(self, table: Table) -> bool:
engine = self._get_sqla_engine()
return engine.has_table(table.table_name, table.schema or None)
with self.get_sqla_engine_with_context() as engine:
return engine.has_table(table.table_name, table.schema or None)
def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
engine = self._get_sqla_engine()
return engine.has_table(table_name, schema)
with self.get_sqla_engine_with_context() as engine:
return engine.has_table(table_name, schema)
@classmethod
def _has_view(

View File

@ -27,10 +27,12 @@ from typing import Dict, List
from urllib.parse import quote
import superset.utils.database
from superset.utils.core import backend
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from sqlalchemy import Table
import pytest
import pytz
@ -79,6 +81,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
load_world_bank_data,
)
from tests.integration_tests.conftest import CTAS_SCHEMA_NAME
logger = logging.getLogger(__name__)
@ -1673,6 +1676,16 @@ class TestCore(SupersetTestCase):
)
self.assertRedirects(rv, f"/explore/?form_data_key={random_key}")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_has_table_by_name(self):
if backend() in ("sqlite", "mysql"):
return
example_db = superset.utils.database.get_example_database()
assert (
example_db.has_table_by_name(table_name="birth_names", schema="public")
is True
)
if __name__ == "__main__":
unittest.main()