Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 138 additions & 20 deletions duckdb/experimental/spark/sql/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

class Database(NamedTuple):
    """A Spark 'database' entry; for the DuckDB backend this is a DuckDB schema."""

    name: str  # DuckDB schema name (what Spark calls the database)
    catalog: str  # DuckDB database name (what Spark calls the catalog)
    description: str | None  # not populated by the DuckDB backend
    locationUri: str  # not populated by the DuckDB backend; set to ""


class Table(NamedTuple):
    """A table or view entry returned by Catalog.listTables."""

    name: str  # table or view name
    database: str | None  # DuckDB schema containing the relation
    catalog: str  # DuckDB database containing the relation
    description: str | None  # the relation's COMMENT, if any
    tableType: str  # "TABLE" or "VIEW"
    isTemporary: bool
Expand All @@ -24,56 +26,172 @@ class Column(NamedTuple): # noqa: D101
nullable: bool
isPartition: bool
isBucket: bool
isCluster: bool


class Function(NamedTuple):
    """A function entry returned by Catalog.listFunctions."""

    name: str  # function name
    catalog: str | None  # DuckDB database the function lives in
    namespace: list[str] | None  # schema path, e.g. ["main"]
    description: str | None
    className: str  # DuckDB has no JVM class; carries the DuckDB function type
    isTemporary: bool


class Catalog: # noqa: D101
class Catalog:
"""Implements the spark catalog API.

Implementation notes:
Spark has the concept of a catalog and inside each catalog there are schemas
which contain tables. But spark calls the schemas as databases through
the catalog API.
For Duckdb, there are databases, which in turn contain schemas. DuckDBs
databases therefore overlap with the concept of the spark catalog.
So to summarize
------------------------------
| Spark | DuckDB |
------------------------------
    | Catalog         | Database |
| Database/Schema | Schema |
------------------------------
The consequence is that this catalog API refers in several locations to a
database name, which is the DuckDB schema.
"""

    def __init__(self, session: SparkSession) -> None:
        """Store the session whose DuckDB connection backs all catalog queries."""
        self._session = session

def listDatabases(self) -> list[Database]: # noqa: D102
res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()
def listDatabases(self, pattern: str | None = None) -> list[Database]:
"""Returns a list of database object for all available databases."""
if pattern:
pattern = pattern.replace("*", "%")
where_sql = " WHERE schema_name LIKE ?"
params = (pattern,)
else:
where_sql = ""
params = ()

sql_text = "select schema_name, database_name from duckdb_schemas()" + where_sql
res = self._session.conn.sql(sql_text, params=params).fetchall()

def transform_to_database(x: list[str]) -> Database:
return Database(name=x[0], description=None, locationUri="")
def transform_to_database(x: tuple[str, ...]) -> Database:
return Database(name=x[0], catalog=x[1], description=None, locationUri="")

databases = [transform_to_database(x) for x in res]
return databases

def listTables(self) -> list[Table]: # noqa: D102
res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall()
def listTables(self, dbName: str | None = None, pattern: str | None = None) -> list[Table]:
"""Returns a list of tables/views in the specified database.

If dbName nor pattern are provided, the current active database is used.
"""
dbName = dbName or self.currentDatabase()
current_catalog = self._currentCatalog()
where_sql1 = where_sql2 = ""
params = (current_catalog, dbName)

if pattern:
where_sql1 = " and table_name LIKE ?"
where_sql2 = " and view_name LIKE ?"
params += (pattern.replace("*", "%"),)

sql_text = (
"select database_name, schema_name, table_name, comment, temporary, 'TABLE'"
f"from duckdb_tables() where database_name = ? and schema_name = ?{where_sql1}"
" union all"
" select database_name, schema_name, view_name, comment, temporary, 'VIEW'"
f" from duckdb_views() where database_name = ? and schema_name = ?{where_sql2}"
)

res = self._session.conn.sql(sql_text, params=(*params, *params)).fetchall()

def transform_to_table(x: list[str]) -> Table:
return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])
return Table(
name=x[2], database=x[1], catalog=x[0], description=x[3], tableType=x[5], isTemporary=bool(x[4])
)

tables = [transform_to_table(x) for x in res]
return tables

def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]: # noqa: D102
query = f"""
select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}'
"""
if dbName:
query += f" and database_name = '{dbName}'"
res = self._session.conn.sql(query).fetchall()
def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]:
"""Returns a list of columns for the given table/view in the specified database."""
query = (
"select column_name, data_type, is_nullable"
" from duckdb_columns()"
" where table_name = ? and schema_name = ? and database_name = ?"
)
dbName = dbName or self.currentDatabase()
params = (tableName, dbName, self._currentCatalog())
res = self._session.conn.sql(query, params=params).fetchall()

if len(res) == 0:
from duckdb.experimental.spark.errors import AnalysisException

msg = f"[TABLE_OR_VIEW_NOT_FOUND] The table or view `{tableName}` cannot be found"
raise AnalysisException(msg)

def transform_to_column(x: list[str | bool]) -> Column:
return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)
return Column(
name=x[0],
description=None,
dataType=x[1],
nullable=x[2],
isPartition=False,
isBucket=False,
isCluster=False,
)

columns = [transform_to_column(x) for x in res]
return columns

def listFunctions(self, dbName: str | None = None) -> list[Function]: # noqa: D102
raise NotImplementedError
def listFunctions(self, dbName: str | None = None, pattern: str | None = None) -> list[Function]:
"""Returns a list of functions registered in the specified database."""
dbName = dbName or self.currentDatabase()
where_sql = ""
params = (dbName,)

if pattern:
pattern = pattern.replace("*", "%")
where_sql = " AND function_name LIKE ?"
params = (pattern,)

sql_text = (
"SELECT DISTINCT database_name, schema_name, function_name, description, function_type"
" FROM duckdb_functions()"
" WHERE schema_name = ? " + where_sql
)

res = self._session.conn.sql(sql_text, params=params).fetchall()

columns = [
Function(
name=x[2],
catalog=x[0],
namespace=[x[1]],
description=x[3],
className=x[4],
isTemporary=x[0] == "temp",
)
for x in res
]
return columns

def currentDatabase(self) -> str:
"""Retrieves the name of the active database/schema."""
res = self._session.conn.sql("SELECT current_schema()").fetchone()
return res[0]

def setCurrentDatabase(self, dbName: str) -> None:
"""Sets the active database/schema. Equivalent to executing 'USE dbName'."""
self._session.conn.sql(f"USE {_sql_quote(dbName)}")

def _currentCatalog(self) -> str:
res = self._session.conn.sql("SELECT current_database()").fetchone()
return res[0]


def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102
raise NotImplementedError
def _sql_quote(value: str) -> str:
return f'"{value.replace('"', '""')}"'


__all__ = ["Catalog", "Column", "Database", "Function", "Table"]
136 changes: 108 additions & 28 deletions tests/fast/spark/test_spark_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,130 @@


class TestSparkCatalog:
def test_list_databases(self, spark):
def test_list_databases_all(self, spark):
dbs = spark.catalog.listDatabases()
if USE_ACTUAL_SPARK:
assert all(isinstance(db, Database) for db in dbs)
else:
assert dbs == [
Database(name="memory", description=None, locationUri=""),
Database(name="system", description=None, locationUri=""),
Database(name="temp", description=None, locationUri=""),
Database(name="main", catalog="memory", description=None, locationUri=""),
Database(name="information_schema", catalog="system", description=None, locationUri=""),
Database(name="main", catalog="system", description=None, locationUri=""),
Database(name="pg_catalog", catalog="system", description=None, locationUri=""),
Database(name="main", catalog="temp", description=None, locationUri=""),
]

def test_list_tables(self, spark):
# empty
def test_create_use_schema(self, spark):
assert spark.catalog.currentDatabase() == "main"

spark.sql("CREATE SCHEMA my_schema1")
spark.catalog.setCurrentDatabase("my_schema1")
assert spark.catalog.currentDatabase() == "my_schema1"

dbs = spark.catalog.listDatabases("*schema1")
assert len(dbs) == 1
assert spark.catalog.currentDatabase() == "my_schema1"

if USE_ACTUAL_SPARK:
return

# Verifying the table goes to the right schema.
spark.sql("create table tbl1(a varchar)")
spark.sql("create table main.tbl2(a varchar)")
expected = [
Table(
name="tbl1",
catalog="memory",
database="my_schema1",
description=None,
tableType="TABLE",
isTemporary=False,
)
]
tbls = spark.catalog.listTables()
assert tbls == expected

spark.sql("DROP TABLE my_schema1.tbl1")
spark.sql("DROP SCHEMA my_schema1")
assert len(spark.catalog.listDatabases("my_schema1")) == 0
assert spark.catalog.currentDatabase() == "main"

    @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Checking duckdb specific databases")
    def test_list_databases_pattern(self, spark):
        """Both '*' globs and exact names filter; no match yields an empty list."""
        expected = [
            Database(name="pg_catalog", catalog="system", description=None, locationUri=""),
        ]
        dbs = spark.catalog.listDatabases("pg*")
        assert dbs == expected
        dbs = spark.catalog.listDatabases("pg_catalog")
        assert dbs == expected
        dbs = spark.catalog.listDatabases("notfound")
        assert dbs == []

    def test_list_tables_empty(self, spark):
        """A fresh session has no tables in the current schema."""
        tbls = spark.catalog.listTables()
        assert tbls == []

if not USE_ACTUAL_SPARK:
# Skip this if we're using actual Spark because we can't create tables
# with our setup.
spark.sql("create table tbl(a varchar)")
tbls = spark.catalog.listTables()
assert tbls == [
Table(
name="tbl",
database="memory",
description="CREATE TABLE tbl(a VARCHAR);",
tableType="",
isTemporary=False,
)
]
@pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Checking duckdb specific tables")
def test_list_tables_create(self, spark):
spark.sql("create table tbl1(a varchar)")
spark.sql("create table tbl2(b varchar); COMMENT ON TABLE tbl2 IS 'hello world'")
expected = [
Table(
name="tbl1", catalog="memory", database="main", description=None, tableType="TABLE", isTemporary=False
),
Table(
name="tbl2",
catalog="memory",
database="main",
description="hello world",
tableType="TABLE",
isTemporary=False,
),
]
tbls = spark.catalog.listTables()
assert tbls == expected

tbls = spark.catalog.listTables(pattern="*l2")
assert tbls == expected[1:]

tbls = spark.catalog.listTables(pattern="tbl2")
assert tbls == expected[1:]

tbls = spark.catalog.listTables(dbName="notfound")
assert tbls == []

spark.sql("create view vw as select * from tbl1")
expected += [
Table(name="vw", catalog="memory", database="main", description=None, tableType="VIEW", isTemporary=False),
]
tbls = spark.catalog.listTables()
assert tbls == expected

@pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup")
def test_list_columns(self, spark):
spark.sql("create table tbl(a varchar, b bool)")
columns = spark.catalog.listColumns("tbl")
assert columns == [
Column(name="a", description=None, dataType="VARCHAR", nullable=True, isPartition=False, isBucket=False),
Column(name="b", description=None, dataType="BOOLEAN", nullable=True, isPartition=False, isBucket=False),
]

# TODO: should this error instead? # noqa: TD002, TD003
non_existant_columns = spark.catalog.listColumns("none_existant")
assert non_existant_columns == []
columns = spark.catalog.listColumns("tbl")
kwds = dict(description=None, nullable=True, isPartition=False, isBucket=False, isCluster=False) # noqa: C408
assert columns == [Column(name="a", dataType="VARCHAR", **kwds), Column(name="b", dataType="BOOLEAN", **kwds)]

spark.sql("create view vw as select * from tbl")
view_columns = spark.catalog.listColumns("vw")
assert view_columns == columns

from spark_namespace.errors import AnalysisException

with pytest.raises(AnalysisException):
assert spark.catalog.listColumns("tbl", "notfound")

    def test_list_columns_not_found(self, spark):
        """listColumns raises AnalysisException for a missing table."""
        from spark_namespace.errors import AnalysisException

        with pytest.raises(AnalysisException):
            spark.catalog.listColumns("none_existant")

    def test_list_functions(self, spark):
        """listFunctions returns a non-empty list including built-ins."""
        fns = spark.catalog.listFunctions()
        assert len(fns)
        assert any(f.name == "current_database" for f in fns)