From 80d291e299a71f3a46d07c6cb2fb2243ee92bbed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Eiras?= Date: Wed, 15 Apr 2026 11:24:12 +0200 Subject: [PATCH] Further develop the Catalog API to match Spark better Implemented missing features. Added missing arguments. Corrected database vs schema handling. --- duckdb/experimental/spark/sql/catalog.py | 158 ++++++++++++++++++++--- tests/fast/spark/test_spark_catalog.py | 136 +++++++++++++++---- 2 files changed, 246 insertions(+), 48 deletions(-) diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index f43bab59..3f0bd5fa 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -5,6 +5,7 @@ class Database(NamedTuple): # noqa: D101 name: str + catalog: str description: str | None locationUri: str @@ -12,6 +13,7 @@ class Database(NamedTuple): # noqa: D101 class Table(NamedTuple): # noqa: D101 name: str database: str | None + catalog: str description: str | None tableType: str isTemporary: bool @@ -24,56 +26,172 @@ class Column(NamedTuple): # noqa: D101 nullable: bool isPartition: bool isBucket: bool + isCluster: bool class Function(NamedTuple): # noqa: D101 name: str + catalog: str | None + namespace: list[str] | None description: str | None className: str isTemporary: bool -class Catalog: # noqa: D101 +class Catalog: + """Implements the spark catalog API. + + Implementation notes: + Spark has the concept of a catalog and inside each catalog there are schemas + which contain tables. But spark calls the schemas as databases through + the catalog API. + For Duckdb, there are databases, which in turn contain schemas. DuckDBs + databases therefore overlap with the concept of the spark catalog. + So to summarize + ------------------------------ + | Spark | DuckDB | + ------------------------------ + ! 
Catalog | Database | + | Database/Schema | Schema | + ------------------------------ + The consequence is that this catalog API refers in several locations to a + database name, which is the DuckDB schema. + """ + def __init__(self, session: SparkSession) -> None: # noqa: D107 self._session = session - def listDatabases(self) -> list[Database]: # noqa: D102 - res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall() + def listDatabases(self, pattern: str | None = None) -> list[Database]: + """Returns a list of database object for all available databases.""" + if pattern: + pattern = pattern.replace("*", "%") + where_sql = " WHERE schema_name LIKE ?" + params = (pattern,) + else: + where_sql = "" + params = () + + sql_text = "select schema_name, database_name from duckdb_schemas()" + where_sql + res = self._session.conn.sql(sql_text, params=params).fetchall() - def transform_to_database(x: list[str]) -> Database: - return Database(name=x[0], description=None, locationUri="") + def transform_to_database(x: tuple[str, ...]) -> Database: + return Database(name=x[0], catalog=x[1], description=None, locationUri="") databases = [transform_to_database(x) for x in res] return databases - def listTables(self) -> list[Table]: # noqa: D102 - res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall() + def listTables(self, dbName: str | None = None, pattern: str | None = None) -> list[Table]: + """Returns a list of tables/views in the specified database. + + If dbName nor pattern are provided, the current active database is used. + """ + dbName = dbName or self.currentDatabase() + current_catalog = self._currentCatalog() + where_sql1 = where_sql2 = "" + params = (current_catalog, dbName) + + if pattern: + where_sql1 = " and table_name LIKE ?" + where_sql2 = " and view_name LIKE ?" 
+            params += (pattern.replace("*", "%"),)
+
+        sql_text = (
+            "select database_name, schema_name, table_name, comment, temporary, 'TABLE'"
+            f" from duckdb_tables() where database_name = ? and schema_name = ?{where_sql1}"
+            " union all"
+            " select database_name, schema_name, view_name, comment, temporary, 'VIEW'"
+            f" from duckdb_views() where database_name = ? and schema_name = ?{where_sql2}"
+        )
+
+        res = self._session.conn.sql(sql_text, params=(*params, *params)).fetchall()
 
         def transform_to_table(x: list[str]) -> Table:
-            return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])
+            return Table(
+                name=x[2], database=x[1], catalog=x[0], description=x[3], tableType=x[5], isTemporary=bool(x[4])
+            )
 
         tables = [transform_to_table(x) for x in res]
         return tables
 
-    def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]:  # noqa: D102
-        query = f"""
-        select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}'
-        """
-        if dbName:
-            query += f" and database_name = '{dbName}'"
-        res = self._session.conn.sql(query).fetchall()
+    def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]:
+        """Returns a list of columns for the given table/view in the specified database."""
+        query = (
+            "select column_name, data_type, is_nullable"
+            " from duckdb_columns()"
+            " where table_name = ? and schema_name = ? and database_name = ?" 
+        )
+        dbName = dbName or self.currentDatabase()
+        params = (tableName, dbName, self._currentCatalog())
+        res = self._session.conn.sql(query, params=params).fetchall()
+
+        if len(res) == 0:
+            from duckdb.experimental.spark.errors import AnalysisException
+
+            msg = f"[TABLE_OR_VIEW_NOT_FOUND] The table or view `{tableName}` cannot be found"
+            raise AnalysisException(msg)
 
         def transform_to_column(x: list[str | bool]) -> Column:
-            return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)
+            return Column(
+                name=x[0],
+                description=None,
+                dataType=x[1],
+                nullable=x[2],
+                isPartition=False,
+                isBucket=False,
+                isCluster=False,
+            )
 
         columns = [transform_to_column(x) for x in res]
         return columns
 
-    def listFunctions(self, dbName: str | None = None) -> list[Function]:  # noqa: D102
-        raise NotImplementedError
+    def listFunctions(self, dbName: str | None = None, pattern: str | None = None) -> list[Function]:
+        """Returns a list of functions registered in the specified database."""
+        dbName = dbName or self.currentDatabase()
+        where_sql = ""
+        params = (dbName,)
+
+        if pattern:
+            pattern = pattern.replace("*", "%")
+            where_sql = " AND function_name LIKE ?"
+            params += (pattern,)
+
+        sql_text = (
+            "SELECT DISTINCT database_name, schema_name, function_name, description, function_type"
+            " FROM duckdb_functions()"
+            " WHERE schema_name = ? " + where_sql
+        )
+
+        res = self._session.conn.sql(sql_text, params=params).fetchall()
+
+        columns = [
+            Function(
+                name=x[2],
+                catalog=x[0],
+                namespace=[x[1]],
+                description=x[3],
+                className=x[4],
+                isTemporary=x[0] == "temp",
+            )
+            for x in res
+        ]
+        return columns
+
+    def currentDatabase(self) -> str:
+        """Retrieves the name of the active database/schema."""
+        res = self._session.conn.sql("SELECT current_schema()").fetchone()
+        return res[0]
+
+    def setCurrentDatabase(self, dbName: str) -> None:
+        """Sets the active database/schema. 
Equivalent to executing 'USE dbName'.""" + self._session.conn.sql(f"USE {_sql_quote(dbName)}") + + def _currentCatalog(self) -> str: + res = self._session.conn.sql("SELECT current_database()").fetchone() + return res[0] + - def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102 - raise NotImplementedError +def _sql_quote(value: str) -> str: + return f'"{value.replace('"', '""')}"' __all__ = ["Catalog", "Column", "Database", "Function", "Table"] diff --git a/tests/fast/spark/test_spark_catalog.py b/tests/fast/spark/test_spark_catalog.py index 8a07a0a7..dd1387a1 100644 --- a/tests/fast/spark/test_spark_catalog.py +++ b/tests/fast/spark/test_spark_catalog.py @@ -7,50 +7,130 @@ class TestSparkCatalog: - def test_list_databases(self, spark): + def test_list_databases_all(self, spark): dbs = spark.catalog.listDatabases() if USE_ACTUAL_SPARK: assert all(isinstance(db, Database) for db in dbs) else: assert dbs == [ - Database(name="memory", description=None, locationUri=""), - Database(name="system", description=None, locationUri=""), - Database(name="temp", description=None, locationUri=""), + Database(name="main", catalog="memory", description=None, locationUri=""), + Database(name="information_schema", catalog="system", description=None, locationUri=""), + Database(name="main", catalog="system", description=None, locationUri=""), + Database(name="pg_catalog", catalog="system", description=None, locationUri=""), + Database(name="main", catalog="temp", description=None, locationUri=""), ] - def test_list_tables(self, spark): - # empty + def test_create_use_schema(self, spark): + assert spark.catalog.currentDatabase() == "main" + + spark.sql("CREATE SCHEMA my_schema1") + spark.catalog.setCurrentDatabase("my_schema1") + assert spark.catalog.currentDatabase() == "my_schema1" + + dbs = spark.catalog.listDatabases("*schema1") + assert len(dbs) == 1 + assert spark.catalog.currentDatabase() == "my_schema1" + + if USE_ACTUAL_SPARK: + return + + # Verifying the table goes 
to the right schema. + spark.sql("create table tbl1(a varchar)") + spark.sql("create table main.tbl2(a varchar)") + expected = [ + Table( + name="tbl1", + catalog="memory", + database="my_schema1", + description=None, + tableType="TABLE", + isTemporary=False, + ) + ] + tbls = spark.catalog.listTables() + assert tbls == expected + + spark.sql("DROP TABLE my_schema1.tbl1") + spark.sql("DROP SCHEMA my_schema1") + assert len(spark.catalog.listDatabases("my_schema1")) == 0 + assert spark.catalog.currentDatabase() == "main" + + @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Checking duckdb specific databases") + def test_list_databases_pattern(self, spark): + expected = [ + Database(name="pg_catalog", catalog="system", description=None, locationUri=""), + ] + dbs = spark.catalog.listDatabases("pg*") + assert dbs == expected + dbs = spark.catalog.listDatabases("pg_catalog") + assert dbs == expected + dbs = spark.catalog.listDatabases("notfound") + assert dbs == [] + + def test_list_tables_empty(self, spark): tbls = spark.catalog.listTables() assert tbls == [] - if not USE_ACTUAL_SPARK: - # Skip this if we're using actual Spark because we can't create tables - # with our setup. 
- spark.sql("create table tbl(a varchar)") - tbls = spark.catalog.listTables() - assert tbls == [ - Table( - name="tbl", - database="memory", - description="CREATE TABLE tbl(a VARCHAR);", - tableType="", - isTemporary=False, - ) - ] + @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Checking duckdb specific tables") + def test_list_tables_create(self, spark): + spark.sql("create table tbl1(a varchar)") + spark.sql("create table tbl2(b varchar); COMMENT ON TABLE tbl2 IS 'hello world'") + expected = [ + Table( + name="tbl1", catalog="memory", database="main", description=None, tableType="TABLE", isTemporary=False + ), + Table( + name="tbl2", + catalog="memory", + database="main", + description="hello world", + tableType="TABLE", + isTemporary=False, + ), + ] + tbls = spark.catalog.listTables() + assert tbls == expected + + tbls = spark.catalog.listTables(pattern="*l2") + assert tbls == expected[1:] + + tbls = spark.catalog.listTables(pattern="tbl2") + assert tbls == expected[1:] + + tbls = spark.catalog.listTables(dbName="notfound") + assert tbls == [] + + spark.sql("create view vw as select * from tbl1") + expected += [ + Table(name="vw", catalog="memory", database="main", description=None, tableType="VIEW", isTemporary=False), + ] + tbls = spark.catalog.listTables() + assert tbls == expected @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_list_columns(self, spark): spark.sql("create table tbl(a varchar, b bool)") - columns = spark.catalog.listColumns("tbl") - assert columns == [ - Column(name="a", description=None, dataType="VARCHAR", nullable=True, isPartition=False, isBucket=False), - Column(name="b", description=None, dataType="BOOLEAN", nullable=True, isPartition=False, isBucket=False), - ] - # TODO: should this error instead? 
# noqa: TD002, TD003 - non_existant_columns = spark.catalog.listColumns("none_existant") - assert non_existant_columns == [] + columns = spark.catalog.listColumns("tbl") + kwds = dict(description=None, nullable=True, isPartition=False, isBucket=False, isCluster=False) # noqa: C408 + assert columns == [Column(name="a", dataType="VARCHAR", **kwds), Column(name="b", dataType="BOOLEAN", **kwds)] spark.sql("create view vw as select * from tbl") view_columns = spark.catalog.listColumns("vw") assert view_columns == columns + + from spark_namespace.errors import AnalysisException + + with pytest.raises(AnalysisException): + assert spark.catalog.listColumns("tbl", "notfound") + + def test_list_columns_not_found(self, spark): + from spark_namespace.errors import AnalysisException + + with pytest.raises(AnalysisException): + spark.catalog.listColumns("none_existant") + + def test_list_functions(self, spark): + fns = spark.catalog.listFunctions() + assert len(fns) + assert any(f.name == "current_database" for f in fns)