"""Builder for Databricks Genie Space ``serialized_space`` payloads.

Provides a typed authoring API over the JSON envelope that backs a Genie Space.
It covers every slot the Genie Space API accepts at create/import time:

* ``data_sources.tables`` and ``data_sources.metric_views`` (including nested
  ``column_configs``)
* ``config.sample_questions``
* ``instructions.text_instructions``
* ``instructions.example_question_sqls``
* ``instructions.join_specs``
* ``instructions.sql_snippets.{filters,expressions,measures}``
* ``benchmarks.questions``

The builder is intentionally a thin layer — no network calls, no LLM
dependencies. Pair it with the ``manage_genie`` MCP tool (``create_or_update``
or ``import`` actions) to push the resulting payload to a workspace. See
``databricks-skills/databricks-genie/spaces-authoring.md`` for a full
authoring walkthrough.

Round-trip behaviour: unknown fields on loaded payloads are preserved, so
existing spaces can be fetched via ``manage_genie(action="export")``, patched
with this builder, and sent back via ``manage_genie(action="import")`` without
losing data the builder does not model. Read-only accessors never create
missing containers, so inspecting a loaded payload does not alter it.
"""

import copy
import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4


class GenieSpaceBuilder:
    """Typed builder for ``serialized_space`` JSON payloads.

    Usage::

        builder = GenieSpaceBuilder(
            title="Sales Analytics",
            description="Explore sales data",
            warehouse_id="abc123",
        )
        builder.add_table("main.sales.orders", description="Order facts")
        builder.add_sample_question("What were total sales last month?")
        builder.add_join_spec(
            left_identifier="main.sales.orders",
            right_identifier="main.sales.customers",
            condition="orders.customer_id = customers.customer_id",
            relationship_type="MANY_TO_ONE",
        )
        envelope = builder.to_envelope()
        # envelope ready for manage_genie(action="import", **envelope)
    """

    # ------------------------------------------------------------------ paths
    # Each constant is a tuple of nested keys into the ``serialized_space``
    # dict. Using tuples (rather than dotted strings) avoids ambiguity when
    # column or table names contain dots.

    TABLES_PATH: Tuple[str, ...] = ("data_sources", "tables")
    METRIC_VIEWS_PATH: Tuple[str, ...] = ("data_sources", "metric_views")
    SAMPLE_QUESTIONS_PATH: Tuple[str, ...] = ("config", "sample_questions")
    TEXT_INSTRUCTIONS_PATH: Tuple[str, ...] = ("instructions", "text_instructions")
    EXAMPLE_QUESTION_SQLS_PATH: Tuple[str, ...] = ("instructions", "example_question_sqls")
    JOIN_SPECS_PATH: Tuple[str, ...] = ("instructions", "join_specs")
    SQL_SNIPPETS_FILTERS_PATH: Tuple[str, ...] = ("instructions", "sql_snippets", "filters")
    SQL_SNIPPETS_EXPRESSIONS_PATH: Tuple[str, ...] = ("instructions", "sql_snippets", "expressions")
    SQL_SNIPPETS_MEASURES_PATH: Tuple[str, ...] = ("instructions", "sql_snippets", "measures")
    BENCHMARKS_PATH: Tuple[str, ...] = ("benchmarks", "questions")

    #: Paths that store a list of ``{"id": ..., ...}`` entries.
    ID_LIST_PATHS: Tuple[Tuple[str, ...], ...] = (
        SAMPLE_QUESTIONS_PATH,
        TEXT_INSTRUCTIONS_PATH,
        EXAMPLE_QUESTION_SQLS_PATH,
        JOIN_SPECS_PATH,
        SQL_SNIPPETS_FILTERS_PATH,
        SQL_SNIPPETS_EXPRESSIONS_PATH,
        SQL_SNIPPETS_MEASURES_PATH,
        BENCHMARKS_PATH,
    )

    #: Valid relationship types accepted by :meth:`add_join_spec`.
    RELATIONSHIP_TYPES = ("ONE_TO_ONE", "ONE_TO_MANY", "MANY_TO_ONE", "MANY_TO_MANY")

    #: Snippet fields whose values are stored as ``[str]`` lists. ``synonyms``
    #: is included so that a bare string is wrapped rather than split into
    #: individual characters.
    _SNIPPET_LIST_FIELDS = frozenset({"sql", "comment", "synonyms"})

    # ------------------------------------------------------------------ init
    def __init__(
        self,
        title: str = "",
        description: str = "",
        warehouse_id: str = "",
        space: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create a builder, optionally seeded from an existing payload dict.

        ``space`` is deep-copied, so later mutations through the builder never
        touch the caller's object. ``version`` defaults to ``2`` when absent.
        """
        self.title = title
        self.description = description
        self.warehouse_id = warehouse_id
        self._space: Dict[str, Any] = copy.deepcopy(space) if space is not None else {}
        self._space.setdefault("version", 2)

    # ----------------------------------------------------------- round-trip
    @classmethod
    def from_json(
        cls,
        serialized_space: Union[str, Dict[str, Any]],
        title: Optional[str] = None,
        description: Optional[str] = None,
        warehouse_id: Optional[str] = None,
    ) -> "GenieSpaceBuilder":
        """Load a builder from a ``serialized_space`` JSON string, dict, or export envelope.

        Accepts:
          * the raw payload (``{"version": 2, ...}``) as a dict
          * the JSON-encoded string form of the same
          * the envelope returned by ``manage_genie(action="export")``
            (``{"title": ..., "serialized_space": "..."}``)

        Explicit ``title`` / ``description`` / ``warehouse_id`` arguments take
        precedence over values found in an envelope.

        Raises:
            TypeError: if the payload is neither a dict nor a JSON string that
                decodes to a dict.
            json.JSONDecodeError: if a string payload is not valid JSON.
        """
        t, d, w = title, description, warehouse_id

        # A dict carrying a "serialized_space" key is treated as an envelope.
        if isinstance(serialized_space, dict) and "serialized_space" in serialized_space:
            envelope = serialized_space
            t = t if t is not None else envelope.get("title", "")
            d = d if d is not None else envelope.get("description", "")
            w = w if w is not None else envelope.get("warehouse_id", "")
            serialized_space = envelope["serialized_space"]

        if isinstance(serialized_space, str):
            serialized_space = json.loads(serialized_space)

        if not isinstance(serialized_space, dict):
            raise TypeError(f"serialized_space must be a dict or JSON string, got {type(serialized_space).__name__}")

        return cls(
            title=t or "",
            description=d or "",
            warehouse_id=w or "",
            space=serialized_space,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep copy of the ``serialized_space`` dict, normalised for emit.

        Normalisation applies these orderings: ``data_sources.tables`` and
        ``data_sources.metric_views`` are sorted by ``identifier``;
        ``column_configs`` within each are sorted by ``column_name``;
        ``id``-keyed lists (sample questions, instructions, snippets,
        benchmarks, joins) are sorted by ``id``. The builder's own state is
        left untouched.
        """
        out = copy.deepcopy(self._space)
        self._normalize(out)
        return out

    def to_json(self, *, indent: Optional[int] = None) -> str:
        """Return the normalised ``serialized_space`` as a JSON string."""
        return json.dumps(self.to_dict(), indent=indent, sort_keys=False)

    @staticmethod
    def _sort_key(field: str) -> Callable[[Any], str]:
        """Return a sort key that reads ``field`` from a dict entry.

        Missing or ``None`` values, and entries that are not dicts, sort as the
        empty string so a partially-populated payload cannot make the sort
        raise on a ``None``/``str`` comparison.
        """

        def key(entry: Any) -> str:
            value = entry.get(field) if isinstance(entry, dict) else None
            return "" if value is None else str(value)

        return key

    @classmethod
    def _normalize(cls, space: Dict[str, Any]) -> None:
        """In-place: sort the lists of ``space`` into their emit order."""
        by_identifier = cls._sort_key("identifier")
        by_column = cls._sort_key("column_name")
        by_id = cls._sort_key("id")

        data_sources = space.get("data_sources")
        if isinstance(data_sources, dict):
            for slot in ("tables", "metric_views"):
                entries = data_sources.get(slot)
                if not isinstance(entries, list):
                    continue
                entries.sort(key=by_identifier)
                for entry in entries:
                    configs = entry.get("column_configs") if isinstance(entry, dict) else None
                    if isinstance(configs, list):
                        configs.sort(key=by_column)

        for path in cls.ID_LIST_PATHS:
            node: Any = space
            for key in path:
                # Any missing or non-dict hop collapses to None and is skipped.
                node = node.get(key) if isinstance(node, dict) else None
            if isinstance(node, list):
                node.sort(key=by_id)

    def to_envelope(self, *, indent: Optional[int] = None) -> Dict[str, str]:
        """Return the full envelope used by ``manage_genie(action="import")``.

        Keys: ``title``, ``description``, ``warehouse_id``, ``serialized_space``
        (the last is a JSON-encoded string).
        """
        return {
            "title": self.title,
            "description": self.description,
            "warehouse_id": self.warehouse_id,
            "serialized_space": self.to_json(indent=indent),
        }

    # ------------------------------------------------------- internal utils
    def _get_list(self, path: Tuple[str, ...]) -> List[Dict[str, Any]]:
        """Return the list at ``path``, creating parent dicts and the list as needed.

        Use this for writes only; reads should go through :meth:`_peek_list`
        so that they do not add empty containers to the payload.

        Raises:
            ValueError: if ``path`` is empty.
            TypeError: if an existing node along ``path`` has the wrong type.
        """
        if not path:
            raise ValueError("path must not be empty")
        node: Any = self._space
        for key in path[:-1]:
            if not isinstance(node, dict):
                raise TypeError(f"Cannot traverse into non-dict at path segment {key!r}")
            node = node.setdefault(key, {})
        if not isinstance(node, dict):
            raise TypeError(f"Cannot set list on non-dict at path {path!r}")
        leaf = node.setdefault(path[-1], [])
        if not isinstance(leaf, list):
            raise TypeError(f"Expected list at {path!r}, found {type(leaf).__name__}")
        return leaf

    def _peek_list(self, path: Tuple[str, ...]) -> List[Dict[str, Any]]:
        """Return the list at ``path`` without creating anything.

        When any part of ``path`` is absent a fresh empty list is returned and
        the payload is left untouched. When the list exists, the live list is
        returned, so in-place edits by callers are reflected in the payload.

        Raises:
            ValueError: if ``path`` is empty.
            TypeError: if an existing node along ``path`` has the wrong type.
        """
        if not path:
            raise ValueError("path must not be empty")
        node: Any = self._space
        for key in path[:-1]:
            if not isinstance(node, dict):
                raise TypeError(f"Cannot traverse into non-dict at path segment {key!r}")
            if key not in node:
                return []
            node = node[key]
        if not isinstance(node, dict):
            raise TypeError(f"Expected dict before last segment of path {path!r}")
        if path[-1] not in node:
            return []
        leaf = node[path[-1]]
        if not isinstance(leaf, list):
            raise TypeError(f"Expected list at {path!r}, found {type(leaf).__name__}")
        return leaf

    @staticmethod
    def _gen_id() -> str:
        """Return a fresh 32-character hex id (``uuid4().hex``)."""
        return uuid4().hex

    @staticmethod
    def _as_str_list(value: Union[str, Iterable[str], None]) -> List[str]:
        """Coerce a scalar string or iterable-of-strings into a list of strings.

        ``None`` and ``""`` become ``[]``; a non-empty string becomes a
        one-element list; any other iterable is materialised with ``list``.
        """
        if value is None or value == "":
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    # ---------------------------------------------- data_sources (tables)
    def _add_data_source(
        self,
        path: Tuple[str, ...],
        identifier: str,
        description: Union[str, Iterable[str], None],
        column_configs: Optional[List[Dict[str, Any]]],
    ) -> Dict[str, Any]:
        """Append a table or metric-view entry at ``path`` and return it."""
        entry: Dict[str, Any] = {
            "identifier": identifier,
            "description": self._as_str_list(description),
        }
        if column_configs:
            entry["column_configs"] = sorted(column_configs, key=self._sort_key("column_name"))
        self._get_list(path).append(entry)
        return entry

    def add_table(
        self,
        identifier: str,
        description: Union[str, Iterable[str], None] = None,
        column_configs: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Add a Unity Catalog table to the space.

        ``identifier`` must be fully qualified (``catalog.schema.table``).
        ``column_configs`` is an optional list of dicts with shape:
        ``{"column_name": ..., "description": [...], "synonyms": [...],
        "enable_format_assistance": bool, "enable_entity_matching": bool,
        "exclude": bool}``. Entries are sorted by ``column_name``.

        Returns the stored entry (the live dict, not a copy).
        """
        return self._add_data_source(self.TABLES_PATH, identifier, description, column_configs)

    def add_metric_view(
        self,
        identifier: str,
        description: Union[str, Iterable[str], None] = None,
        column_configs: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Add a Unity Catalog metric view to the space.

        Takes the same arguments as :meth:`add_table` and returns the stored
        entry (the live dict, not a copy).
        """
        return self._add_data_source(self.METRIC_VIEWS_PATH, identifier, description, column_configs)

    def add_column_config(
        self,
        table_identifier: str,
        column_name: str,
        *,
        description: Union[str, Iterable[str], None] = None,
        synonyms: Optional[Iterable[str]] = None,
        enable_format_assistance: Optional[bool] = None,
        enable_entity_matching: Optional[bool] = None,
        exclude: bool = False,
    ) -> Dict[str, Any]:
        """Add or replace a column_config on a table or metric_view entry.

        Looks up the data-source entry by ``identifier`` across both
        ``tables`` and ``metric_views`` lists (tables first).
        Existing entries for ``column_name`` are replaced (not duplicated).
        The resulting ``column_configs`` list is kept sorted by column name.
        When ``exclude`` is true only ``{"column_name", "exclude"}`` is
        written and the other options are ignored. A bare-string ``synonyms``
        is treated as a single synonym.

        Raises:
            KeyError: if no table or metric view has ``table_identifier``.
        """
        for path in (self.TABLES_PATH, self.METRIC_VIEWS_PATH):
            # Peek, not get: a failed lookup must not leave empty lists behind.
            for entry in self._peek_list(path):
                if entry.get("identifier") != table_identifier:
                    continue
                cc: Dict[str, Any] = {"column_name": column_name}
                if exclude:
                    cc["exclude"] = True
                else:
                    if description:
                        cc["description"] = self._as_str_list(description)
                    if synonyms:
                        cc["synonyms"] = self._as_str_list(synonyms)
                    if enable_format_assistance is not None:
                        cc["enable_format_assistance"] = bool(enable_format_assistance)
                    if enable_entity_matching is not None:
                        cc["enable_entity_matching"] = bool(enable_entity_matching)
                configs = entry.setdefault("column_configs", [])
                configs[:] = [c for c in configs if c.get("column_name") != column_name]
                configs.append(cc)
                configs.sort(key=self._sort_key("column_name"))
                return cc
        raise KeyError(f"No table or metric_view found with identifier: {table_identifier!r}")

    # -------------------------------------------- config.sample_questions
    def add_sample_question(self, question: str, *, id: Optional[str] = None) -> Dict[str, Any]:
        """Add a sample question (natural language only; no SQL).

        A new id is generated when ``id`` is not supplied.
        """
        entry = {"id": id or self._gen_id(), "question": [question]}
        self._get_list(self.SAMPLE_QUESTIONS_PATH).append(entry)
        return entry

    # ----------------------------------- instructions.text_instructions
    def add_text_instruction(self, content: str, *, id: Optional[str] = None) -> Dict[str, Any]:
        """Add a free-form text instruction, stored as ``{"content": [content]}``."""
        entry = {"id": id or self._gen_id(), "content": [content]}
        self._get_list(self.TEXT_INSTRUCTIONS_PATH).append(entry)
        return entry

    # ----------------------------- instructions.example_question_sqls
    def add_example_sql(self, question: str, sql: str, *, id: Optional[str] = None) -> Dict[str, Any]:
        """Add a question/SQL example pair.

        Both values are stored as single-element lists under ``question`` and
        ``sql``.
        """
        entry = {"id": id or self._gen_id(), "question": [question], "sql": [sql]}
        self._get_list(self.EXAMPLE_QUESTION_SQLS_PATH).append(entry)
        return entry

    # -------------------------------------- instructions.join_specs
    def add_join_spec(
        self,
        left_identifier: str,
        right_identifier: str,
        condition: str,
        *,
        left_alias: str = "",
        right_alias: str = "",
        relationship_type: str = "",
        comment: str = "",
        id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Add a join definition between two data sources.

        ``left_identifier`` and ``right_identifier`` are fully-qualified table
        or metric-view identifiers (``catalog.schema.table``). ``condition`` is
        the join predicate; reference columns via the table alias when one is
        given. ``relationship_type``, when provided, must be one of
        :attr:`RELATIONSHIP_TYPES` and is appended to the ``sql`` list as a
        ``--rt=FROM_RELATIONSHIP_TYPE_<type>--`` marker. Empty aliases and an
        empty comment are omitted from the entry.

        Raises:
            ValueError: if ``relationship_type`` is non-empty and not valid.
        """
        if relationship_type and relationship_type not in self.RELATIONSHIP_TYPES:
            raise ValueError(f"relationship_type must be one of {self.RELATIONSHIP_TYPES}, got {relationship_type!r}")

        left: Dict[str, Any] = {"identifier": left_identifier}
        if left_alias:
            left["alias"] = left_alias
        right: Dict[str, Any] = {"identifier": right_identifier}
        if right_alias:
            right["alias"] = right_alias

        sql_parts: List[str] = [condition]
        if relationship_type:
            sql_parts.append(f"--rt=FROM_RELATIONSHIP_TYPE_{relationship_type}--")

        entry: Dict[str, Any] = {
            "id": id or self._gen_id(),
            "left": left,
            "right": right,
            "sql": sql_parts,
        }
        if comment:
            entry["comment"] = self._as_str_list(comment)
        self._get_list(self.JOIN_SPECS_PATH).append(entry)
        return entry

    # ------------------------------- instructions.sql_snippets.filters
    def add_sql_filter(
        self,
        sql: str,
        *,
        display_name: str = "",
        synonyms: Optional[Iterable[str]] = None,
        comment: str = "",
        id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Add a named, reusable WHERE-style predicate (e.g. status filters).

        ``sql`` is the predicate body without ``WHERE`` (e.g.
        ``"category_l3 = 'Dresses'"``).
        """
        return self._add_snippet(
            self.SQL_SNIPPETS_FILTERS_PATH,
            id=id,
            sql=sql,
            display_name=display_name,
            synonyms=synonyms,
            comment=comment,
        )

    # --------------------------- instructions.sql_snippets.expressions
    def add_sql_expression(
        self,
        alias: str,
        sql: str,
        *,
        display_name: str = "",
        synonyms: Optional[Iterable[str]] = None,
        comment: str = "",
        id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Add a named, reusable SELECT-list expression (date extracts, CASE bucketing).

        ``alias`` is the projected column name (e.g. ``"snapshot_year"``).
        """
        return self._add_snippet(
            self.SQL_SNIPPETS_EXPRESSIONS_PATH,
            id=id,
            alias=alias,
            sql=sql,
            display_name=display_name,
            synonyms=synonyms,
            comment=comment,
        )

    # ----------------------------- instructions.sql_snippets.measures
    def add_sql_measure(
        self,
        alias: str,
        sql: str,
        *,
        display_name: str = "",
        synonyms: Optional[Iterable[str]] = None,
        comment: str = "",
        id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Add a named, reusable aggregate (e.g. ``SUM(orders.total_amount)``).

        ``alias`` is the measure identifier; it is stored under the ``alias``
        key, not ``name``.
        """
        return self._add_snippet(
            self.SQL_SNIPPETS_MEASURES_PATH,
            id=id,
            alias=alias,
            sql=sql,
            display_name=display_name,
            synonyms=synonyms,
            comment=comment,
        )

    def _add_snippet(
        self,
        path: Tuple[str, ...],
        *,
        id: Optional[str],
        **fields: Any,
    ) -> Dict[str, Any]:
        """Shared implementation for filters / expressions / measures snippets.

        Fields whose value is ``None`` or ``""`` are omitted. Fields named in
        :attr:`_SNIPPET_LIST_FIELDS` are coerced to lists of strings; every
        other field is stored as given.
        """
        entry: Dict[str, Any] = {"id": id or self._gen_id()}
        for key, value in fields.items():
            if value is None or value == "":
                continue
            if key in self._SNIPPET_LIST_FIELDS:
                entry[key] = self._as_str_list(value)
            else:
                entry[key] = value
        self._get_list(path).append(entry)
        return entry

    # --------------------------------------------- benchmarks.questions
    def add_benchmark(self, question: str, sql: str, *, id: Optional[str] = None) -> Dict[str, Any]:
        """Add a benchmark question with a canonical SQL answer.

        The stored shape is ``{"question": [question], "answer":
        [{"format": "SQL", "content": [sql]}]}`` plus the ``id``.
        """
        entry = {
            "id": id or self._gen_id(),
            "question": [question],
            "answer": [{"format": "SQL", "content": [sql]}],
        }
        self._get_list(self.BENCHMARKS_PATH).append(entry)
        return entry

    # ---------------------------------------------------- generic list ops
    def find_by_id(self, path: Iterable[str], id: str) -> Optional[Dict[str, Any]]:
        """Return the entry with the given ``id`` at ``path``, or ``None``.

        Does not modify the payload, even when ``path`` does not exist yet.
        """
        for entry in self._peek_list(tuple(path)):
            if entry.get("id") == id:
                return entry
        return None

    def replace_by_id(self, path: Iterable[str], id: str, new_entry: Dict[str, Any]) -> bool:
        """Replace the entry with ``id`` at ``path``. Returns ``True`` if replaced.

        ``new_entry`` is stored as-is; its own ``id`` is not checked.
        """
        lst = self._peek_list(tuple(path))
        for idx, entry in enumerate(lst):
            if entry.get("id") == id:
                lst[idx] = new_entry
                return True
        return False

    def remove_by_id(self, path: Iterable[str], id: str) -> bool:
        """Remove the first entry with ``id`` at ``path``. Returns ``True`` if removed."""
        lst = self._peek_list(tuple(path))
        for idx, entry in enumerate(lst):
            if entry.get("id") == id:
                lst.pop(idx)
                return True
        return False

    # ----------------------------------------------------- bulk accessors
    # Each accessor returns a new list holding the live entry dicts; none of
    # them adds keys to the payload.
    def list_tables(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``data_sources.tables`` list."""
        return list(self._peek_list(self.TABLES_PATH))

    def list_metric_views(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``data_sources.metric_views`` list."""
        return list(self._peek_list(self.METRIC_VIEWS_PATH))

    def list_sample_questions(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``config.sample_questions`` list."""
        return list(self._peek_list(self.SAMPLE_QUESTIONS_PATH))

    def list_text_instructions(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``instructions.text_instructions`` list."""
        return list(self._peek_list(self.TEXT_INSTRUCTIONS_PATH))

    def list_example_sqls(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``instructions.example_question_sqls`` list."""
        return list(self._peek_list(self.EXAMPLE_QUESTION_SQLS_PATH))

    def list_join_specs(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``instructions.join_specs`` list."""
        return list(self._peek_list(self.JOIN_SPECS_PATH))

    def list_sql_filters(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``instructions.sql_snippets.filters`` list."""
        return list(self._peek_list(self.SQL_SNIPPETS_FILTERS_PATH))

    def list_sql_expressions(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``instructions.sql_snippets.expressions`` list."""
        return list(self._peek_list(self.SQL_SNIPPETS_EXPRESSIONS_PATH))

    def list_sql_measures(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``instructions.sql_snippets.measures`` list."""
        return list(self._peek_list(self.SQL_SNIPPETS_MEASURES_PATH))

    def list_benchmarks(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the ``benchmarks.questions`` list."""
        return list(self._peek_list(self.BENCHMARKS_PATH))


__all__ = ["GenieSpaceBuilder"]
"""Unit tests for GenieSpaceBuilder.

These tests exercise the pure-Python authoring surface. They do not call the
Databricks workspace — only the builder's in-memory state is verified.
"""

import json

import pytest

from databricks_mcp_server.tools.genie_space_builder import GenieSpaceBuilder


# ---------------------------------------------------------------- helpers
def _ids(entries):
    """Collect the ``id`` of every entry into a set."""
    return {item["id"] for item in entries}


# --------------------------------------------------------------- init / io
def test_init_sets_version_two():
    """A fresh builder emits ``version == 2``."""
    assert GenieSpaceBuilder().to_dict()["version"] == 2


def test_round_trip_preserves_unknown_fields():
    """Keys the builder does not model survive load -> emit."""
    payload = {
        "version": 2,
        "data_sources": {"tables": []},
        "some_future_field": {"foo": "bar"},
    }
    emitted = GenieSpaceBuilder.from_json(payload).to_dict()
    assert emitted["some_future_field"] == {"foo": "bar"}


def test_from_json_accepts_string():
    """A JSON-encoded payload string is decoded on load."""
    text = json.dumps({"version": 2, "config": {"sample_questions": []}})
    builder = GenieSpaceBuilder.from_json(text)
    assert builder.list_sample_questions() == []


def test_from_json_accepts_envelope():
    """An export envelope supplies title, description and warehouse id."""
    envelope = {
        "title": "My Space",
        "description": "Desc",
        "warehouse_id": "wh_1",
        "serialized_space": json.dumps({"version": 2}),
    }
    builder = GenieSpaceBuilder.from_json(envelope)
    assert builder.title == "My Space"
    assert builder.description == "Desc"
    assert builder.warehouse_id == "wh_1"


def test_from_json_rejects_non_dict_payload():
    """A payload that is neither dict nor string raises ``TypeError``."""
    with pytest.raises(TypeError):
        GenieSpaceBuilder.from_json(123)  # type: ignore[arg-type]


def test_to_envelope_contains_json_string():
    """The envelope carries metadata plus a JSON string payload."""
    envelope = GenieSpaceBuilder(title="T", description="D", warehouse_id="W").to_envelope()
    assert envelope["title"] == "T"
    assert envelope["description"] == "D"
    assert envelope["warehouse_id"] == "W"
    assert isinstance(envelope["serialized_space"], str)
    assert json.loads(envelope["serialized_space"])["version"] == 2


# -------------------------------------------------------------- data_sources
def test_add_table_and_metric_view():
    """Tables and metric views land in their own lists."""
    builder = GenieSpaceBuilder()
    builder.add_table("cat.sch.orders", description="Order facts")
    builder.add_metric_view("cat.sch.metrics_orders", description="Order metrics")
    assert len(builder.list_tables()) == 1
    assert len(builder.list_metric_views()) == 1
    first_table = builder.list_tables()[0]
    assert first_table["identifier"] == "cat.sch.orders"
    assert first_table["description"] == ["Order facts"]


def test_add_column_config_on_table():
    """A column config is attached to the matching table entry."""
    builder = GenieSpaceBuilder()
    builder.add_table("cat.sch.orders")
    builder.add_column_config(
        "cat.sch.orders",
        "order_id",
        description="Unique order identifier",
        synonyms=["order", "id"],
        enable_entity_matching=True,
        enable_format_assistance=True,
    )
    first = builder.list_tables()[0]["column_configs"][0]
    assert first["column_name"] == "order_id"
    assert first["description"] == ["Unique order identifier"]
    assert first["enable_entity_matching"] is True


def test_add_column_config_replaces_existing():
    """Re-adding a column name overwrites rather than duplicates."""
    builder = GenieSpaceBuilder()
    builder.add_table("cat.sch.orders")
    for text in ("v1", "v2"):
        builder.add_column_config("cat.sch.orders", "status", description=text)
    configs = builder.list_tables()[0]["column_configs"]
    assert len(configs) == 1
    assert configs[0]["description"] == ["v2"]


def test_add_column_config_sorted_by_name():
    """Column configs are kept ordered by ``column_name``."""
    builder = GenieSpaceBuilder()
    builder.add_table("cat.sch.orders")
    for column in ("zzz", "aaa", "mmm"):
        builder.add_column_config("cat.sch.orders", column)
    configs = builder.list_tables()[0]["column_configs"]
    assert [cfg["column_name"] for cfg in configs] == ["aaa", "mmm", "zzz"]


def test_add_column_config_exclude():
    """``exclude=True`` writes only the name and the exclude flag."""
    builder = GenieSpaceBuilder()
    builder.add_table("cat.sch.orders")
    builder.add_column_config("cat.sch.orders", "_rescued_data", exclude=True)
    stored = builder.list_tables()[0]["column_configs"][0]
    assert stored == {"column_name": "_rescued_data", "exclude": True}


def test_add_column_config_unknown_identifier_raises():
    """An unknown table identifier raises ``KeyError``."""
    builder = GenieSpaceBuilder()
    with pytest.raises(KeyError):
        builder.add_column_config("cat.sch.missing", "col_a")


# -------------------------------------------------------- config & instructions
def test_sample_questions_get_unique_ids():
    """Generated ids differ and are 32 characters long."""
    builder = GenieSpaceBuilder()
    first = builder.add_sample_question("Q1?")
    second = builder.add_sample_question("Q2?")
    assert first["id"] != second["id"]
    assert len(first["id"]) == 32  # uuid4 hex


def test_text_instruction_roundtrips_content():
    """Instruction text is stored verbatim in a one-element list."""
    text = "## Best practices\n* tip one"
    stored = GenieSpaceBuilder().add_text_instruction(text)
    assert stored["content"] == ["## Best practices\n* tip one"]


def test_example_sql_adds_question_and_sql():
    """Question and SQL are each wrapped in a list."""
    stored = GenieSpaceBuilder().add_example_sql(
        "What were total sales?",
        "SELECT SUM(total_amount) FROM orders",
    )
    assert stored["question"] == ["What were total sales?"]
    assert stored["sql"] == ["SELECT SUM(total_amount) FROM orders"]


def test_join_spec_full_fields():
    """All optional join fields appear when supplied."""
    join = GenieSpaceBuilder().add_join_spec(
        left_identifier="cat.sch.orders",
        right_identifier="cat.sch.customers",
        condition="`o`.`customer_id` = `c`.`customer_id`",
        left_alias="o",
        right_alias="c",
        relationship_type="MANY_TO_ONE",
        comment="Each order has one customer",
    )
    assert join["left"] == {"identifier": "cat.sch.orders", "alias": "o"}
    assert join["right"] == {"identifier": "cat.sch.customers", "alias": "c"}
    assert join["sql"][0] == "`o`.`customer_id` = `c`.`customer_id`"
    assert join["sql"][1] == "--rt=FROM_RELATIONSHIP_TYPE_MANY_TO_ONE--"
    assert join["comment"] == ["Each order has one customer"]


def test_join_spec_omits_empty_optional_fields():
    """Unset aliases, relationship type and comment are left out."""
    join = GenieSpaceBuilder().add_join_spec(
        left_identifier="cat.sch.a",
        right_identifier="cat.sch.b",
        condition="a.id = b.id",
    )
    assert "alias" not in join["left"]
    assert "alias" not in join["right"]
    assert join["sql"] == ["a.id = b.id"]
    assert "comment" not in join


def test_join_spec_invalid_relationship_type_raises():
    """An unknown relationship type raises ``ValueError``."""
    builder = GenieSpaceBuilder()
    with pytest.raises(ValueError):
        builder.add_join_spec(
            left_identifier="a",
            right_identifier="b",
            condition="a.id = b.id",
            relationship_type="MANY_TO_FEW",
        )


# -------------------------------------------------------------- sql_snippets
def test_sql_filter_measure_expression():
    """Each snippet kind is stored in its own list with the expected shape."""
    builder = GenieSpaceBuilder()
    builder.add_sql_filter(
        "orders.status = 'Confirmed'",
        display_name="Confirmed",
        synonyms=["active"],
        comment="Only confirmed orders",
    )
    builder.add_sql_measure(
        "total_revenue",
        "SUM(orders.total_amount)",
        display_name="Total Revenue",
    )
    builder.add_sql_expression(
        "order_year",
        "YEAR(orders.order_date)",
        display_name="Order Year",
    )
    filters = builder.list_sql_filters()
    measures = builder.list_sql_measures()
    expressions = builder.list_sql_expressions()
    assert len(filters) == 1
    assert len(measures) == 1
    assert len(expressions) == 1

    flt = filters[0]
    assert flt["display_name"] == "Confirmed"
    assert flt["sql"] == ["orders.status = 'Confirmed'"]
    assert flt["comment"] == ["Only confirmed orders"]

    measure = measures[0]
    assert measure["alias"] == "total_revenue"
    assert measure["sql"] == ["SUM(orders.total_amount)"]
    assert "name" not in measure  # API field is `alias`, not `name`

    expression = expressions[0]
    assert expression["alias"] == "order_year"
    assert expression["sql"] == ["YEAR(orders.order_date)"]


def test_sql_snippet_strips_empty_fields():
    """Empty optional snippet fields are not written."""
    stored = GenieSpaceBuilder().add_sql_measure("cnt", "COUNT(*)")
    assert "display_name" not in stored
    assert "comment" not in stored
    assert stored["sql"] == ["COUNT(*)"]


# --------------------------------------------------------------- benchmarks
def test_benchmark_structure():
    """A benchmark stores the question and a SQL-format answer."""
    stored = GenieSpaceBuilder().add_benchmark(
        "How many orders?",
        "SELECT COUNT(*) FROM orders",
    )
    assert stored["question"] == ["How many orders?"]
    assert stored["answer"] == [{"format": "SQL", "content": ["SELECT COUNT(*) FROM orders"]}]


# ------------------------------------------------------------- generic ops
def test_find_replace_remove_by_id():
    """find / replace / remove operate on the entry with the given id."""
    builder = GenieSpaceBuilder()
    path = builder.SAMPLE_QUESTIONS_PATH
    first = builder.add_sample_question("Q1?")
    second = builder.add_sample_question("Q2?")

    assert builder.find_by_id(path, first["id"]) == first
    assert builder.find_by_id(path, "missing") is None

    assert builder.replace_by_id(
        path,
        first["id"],
        {"id": first["id"], "question": ["Updated?"]},
    )
    assert builder.find_by_id(path, first["id"])["question"] == ["Updated?"]

    assert builder.remove_by_id(path, second["id"])
    assert builder.find_by_id(path, second["id"]) is None
    assert len(builder.list_sample_questions()) == 1


def test_replace_by_id_returns_false_for_missing():
    """Replacing an id that is not present reports ``False``."""
    builder = GenieSpaceBuilder()
    builder.add_sample_question("Q?")
    assert not builder.replace_by_id(builder.SAMPLE_QUESTIONS_PATH, "does-not-exist", {})


def test_ids_are_unique_across_slots():
    """One entry per id-keyed slot yields eight distinct ids."""
    builder = GenieSpaceBuilder()
    builder.add_sample_question("A?")
    builder.add_text_instruction("T")
    builder.add_example_sql("Q?", "SELECT 1")
    builder.add_join_spec("cat.sch.a", "cat.sch.b", "a.id = b.id")
    builder.add_sql_filter("a.x = 1")
    builder.add_sql_expression("alias", "YEAR(a.d)")
    builder.add_sql_measure("m", "COUNT(*)")
    builder.add_benchmark("bq", "SELECT 1")

    seen = {
        entry["id"]
        for path in GenieSpaceBuilder.ID_LIST_PATHS
        for entry in builder._get_list(path)
    }
    assert len(seen) == 8


def test_explicit_id_honored():
    """A caller-supplied id is used unchanged."""
    stored = GenieSpaceBuilder().add_sample_question("Q?", id="deadbeef" * 4)
    assert stored["id"] == "deadbeef" * 4


# ------------------------------------------------------- full roundtrip test
def test_full_build_to_import_envelope():
    """End-to-end: build a space, export envelope, reload, confirm equivalence."""
    source = GenieSpaceBuilder(title="Sales", description="D", warehouse_id="wh")
    source.add_table("cat.sch.orders", description="Orders fact")
    source.add_table("cat.sch.customers", description="Customers dim")
    source.add_column_config("cat.sch.orders", "order_id", description="id")
    source.add_sample_question("What were sales last month?")
    source.add_example_sql("Total sales?", "SELECT SUM(amt) FROM cat.sch.orders")
    source.add_join_spec(
        left_identifier="cat.sch.orders",
        right_identifier="cat.sch.customers",
        condition="orders.customer_id = customers.customer_id",
        relationship_type="MANY_TO_ONE",
    )
    source.add_sql_measure("rev", "SUM(orders.amt)", display_name="Revenue")
    source.add_text_instruction("Always round to 2 decimals.")
    source.add_benchmark("How many?", "SELECT COUNT(*) FROM orders")

    reloaded = GenieSpaceBuilder.from_json(source.to_envelope())

    assert reloaded.title == "Sales"
    assert reloaded.warehouse_id == "wh"
    assert len(reloaded.list_tables()) == 2
    assert len(reloaded.list_join_specs()) == 1
    assert len(reloaded.list_sql_measures()) == 1
    assert len(reloaded.list_benchmarks()) == 1
    assert reloaded.to_dict() == source.to_dict()
src.add_example_sql("Total sales?", "SELECT SUM(amt) FROM cat.sch.orders") + src.add_join_spec( + left_identifier="cat.sch.orders", + right_identifier="cat.sch.customers", + condition="orders.customer_id = customers.customer_id", + relationship_type="MANY_TO_ONE", + ) + src.add_sql_measure("rev", "SUM(orders.amt)", display_name="Revenue") + src.add_text_instruction("Always round to 2 decimals.") + src.add_benchmark("How many?", "SELECT COUNT(*) FROM orders") + + envelope = src.to_envelope() + dst = GenieSpaceBuilder.from_json(envelope) + + assert dst.title == "Sales" + assert dst.warehouse_id == "wh" + assert len(dst.list_tables()) == 2 + assert len(dst.list_join_specs()) == 1 + assert len(dst.list_sql_measures()) == 1 + assert len(dst.list_benchmarks()) == 1 + assert dst.to_dict() == src.to_dict() diff --git a/databricks-skills/databricks-genie/SKILL.md b/databricks-skills/databricks-genie/SKILL.md index 82332476..6a0b63a3 100644 --- a/databricks-skills/databricks-genie/SKILL.md +++ b/databricks-skills/databricks-genie/SKILL.md @@ -174,6 +174,7 @@ manage_genie( ## Reference Files - [spaces.md](spaces.md) - Creating and managing Genie Spaces +- [spaces-authoring.md](spaces-authoring.md) - Building rich spaces with joins, SQL snippets, example SQL, text instructions, and benchmarks via the `GenieSpaceBuilder` helper - [conversation.md](conversation.md) - Asking questions via the Conversation API ## Prerequisites diff --git a/databricks-skills/databricks-genie/spaces-authoring.md b/databricks-skills/databricks-genie/spaces-authoring.md new file mode 100644 index 00000000..3f98a7ca --- /dev/null +++ b/databricks-skills/databricks-genie/spaces-authoring.md @@ -0,0 +1,337 @@ +# Authoring a Rich Genie Space + +This guide covers building a Genie Space with the full `serialized_space` +surface populated — joins, SQL snippets (filters / expressions / measures), +text instructions, example question SQL, benchmarks, and column configs — not +just the minimal create flow. 
+ +The minimal `manage_genie(action="create_or_update", ...)` call takes +`display_name`, `table_identifiers`, `description`, and `sample_questions` +and stops there. For richer spaces you push a fully-populated +`serialized_space` through `manage_genie(action="import")`. This doc walks +through building that payload end-to-end. + +See [references/schema.md](references/schema.md) for the full field reference +and [references/best-practices.md](references/best-practices.md) for +authoring conventions (both reference files are introduced in PR #473 and may +not exist until it is merged). + +## Where the artifacts live + +| Artifact | `serialized_space` path | Purpose | +|---|---|---| +| Tables | `data_sources.tables[]` | UC tables Genie can query directly | +| Metric views | `data_sources.metric_views[]` | UC metric views (semantic layer: dimensions, measures, joins in YAML) | +| Column configs | `data_sources.{tables,metric_views}[].column_configs[]` | Per-column description, synonyms, entity matching, format assistance, exclusion | +| Sample questions | `config.sample_questions[]` | Starter questions shown on the space home page | +| Text instructions | `instructions.text_instructions[]` | Markdown guidance the model reads on every query | +| Example question SQL | `instructions.example_question_sqls[]` | Certified Q&A pairs — strongest steering signal | +| Join specs | `instructions.join_specs[]` | Declared joins the model uses instead of guessing | +| SQL snippets — measures | `instructions.sql_snippets.measures[]` | Named reusable aggregates (`SUM(...)`, `COUNT(DISTINCT ...)`) | +| SQL snippets — filters | `instructions.sql_snippets.filters[]` | Named reusable WHERE clauses | +| SQL snippets — expressions | `instructions.sql_snippets.expressions[]` | Named reusable SELECT expressions (date extracts, CASE bucketing) | +| Benchmarks | `benchmarks.questions[]` | Ground-truth Q&A pairs for quality evaluation | + +Semantic-layer dimensions and measures for a Genie Space are **not** fields in +`serialized_space` — they 
live in Unity Catalog **metric views** (created with +`CREATE VIEW ... WITH METRICS LANGUAGE YAML`), which Genie consumes when you +list them under `data_sources.metric_views[]`. Build the semantic layer in UC, +then reference the metric view by its fully-qualified name. + +## The builder + +Use `GenieSpaceBuilder` from `databricks_mcp_server.tools.genie_space_builder` +to author payloads without hand-rolling JSON. It provides path constants and +typed `add_*` / `replace_*` / `find_by_id` / `to_json` / `from_json` +helpers, handles ID generation per API spec, and preserves unknown fields on +round-trip. + +```python +from databricks_mcp_server.tools.genie_space_builder import GenieSpaceBuilder + +builder = GenieSpaceBuilder( + title="Sales Analytics", + description="Explore sales data with natural language", + warehouse_id="abc123", +) +``` + +All snippets below assume a `builder` in scope. + +## The seven-step authoring pipeline + +The steps below mirror the pipeline used by +[`sunnysingh-db/ai-genie-space-generator`](https://github.com/sunnysingh-db/ai-genie-space-generator), +a public Databricks Solutions reference implementation. You can run each step +by hand or drive it with an LLM — the builder does not care. + +### 1. Scan table metadata + +Inspect schemas and sample data before you decide which slots to fill: + +```python +get_table_stats_and_schema( + catalog="my_catalog", + schema="sales", + table_stat_level="DEEP", +) +``` + +Use the output to identify temporal, categorical, and numeric columns, +cardinalities, and plausible join keys. + +### 2. Build the semantic layer (dimensions + measures) via metric views + +Create a UC metric view per fact table. The YAML body defines dimensions, +measures, and joins in the UC semantic-layer format. 
Then register the view +with the builder: + +```python +execute_sql( + sql=""" + CREATE OR REPLACE VIEW my_catalog.sales.metrics_orders + WITH METRICS + LANGUAGE YAML + AS $$ + version: 0.1 + source: my_catalog.sales.orders + dimensions: + - name: order_date + expr: order_date + - name: order_status + expr: status + measures: + - name: total_revenue + expr: SUM(total_amount) + - name: order_count + expr: COUNT(*) + joins: + - name: customers + source: my_catalog.sales.customers + on: orders.customer_id = customers.customer_id + $$ + """ +) + +builder.add_metric_view( + "my_catalog.sales.metrics_orders", + description="Order facts with revenue and count measures", +) +``` + +Raw tables still go under `data_sources.tables` — include them when users +need access to columns that are not covered by the metric view. + +### 3. Declare joins + +Even when metric views carry join definitions internally, adding explicit +`join_specs` at the space level gives the model reliable cross-table guidance +for ad-hoc queries: + +```python +builder.add_join_spec( + left_identifier="my_catalog.sales.orders", + right_identifier="my_catalog.sales.customers", + condition="`o`.`customer_id` = `c`.`customer_id`", + left_alias="o", + right_alias="c", + relationship_type="MANY_TO_ONE", + comment="Each order is placed by one customer", +) +``` + +Ground your join columns against the actual column names — do not assume. +`relationship_type` must be one of `ONE_TO_ONE`, `ONE_TO_MANY`, +`MANY_TO_ONE`, `MANY_TO_MANY`. The builder encodes it as a marker inside the +`sql` list (the format the API uses to round-trip the value). + +### 4. Add table and column descriptions + +Genie reads descriptions when reasoning about schema. Add them at the table +level and at the column level (via `column_configs`) for anything non-obvious: + +```python +builder.add_table( + "my_catalog.sales.orders", + description=( + "Order-level fact table. One row per order. " + "Joined to customers via customer_id." 
+ ), +) +builder.add_column_config( + "my_catalog.sales.orders", + "status", + description="Order lifecycle state", + synonyms=["state", "order state"], + enable_entity_matching=True, +) +builder.add_column_config( + "my_catalog.sales.orders", + "_rescued_data", + exclude=True, +) +``` + +`enable_entity_matching` on low-cardinality categorical columns lets Genie +resolve "confirmed orders" to `status = 'Confirmed'` even when the user's +phrasing does not match the stored value exactly. + +### 5. Write sample questions, with SQL for the important ones + +Sample questions appear as clickable tiles on the space home page. Pair each +strategic question with certified SQL via `add_example_sql` — those pairs +anchor the model's generations more strongly than natural language +instructions alone: + +```python +builder.add_sample_question("What were total sales last month?") +builder.add_sample_question("Which product categories grew the most this quarter?") +builder.add_sample_question("Who are our top 10 customers by revenue?") + +builder.add_example_sql( + "What were total sales last month?", + """ + SELECT SUM(total_amount) AS total_sales + FROM my_catalog.sales.orders + WHERE order_date >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL 1 MONTH) + AND order_date < DATE_TRUNC('month', CURRENT_DATE) + """, +) +``` + +Write questions in everyday business language — never reference column names, +struct paths, or underscores in the question text itself. + +### 6. Add reusable SQL snippets (measures, filters, expressions) + +Snippets are named SQL fragments Genie can substitute by name. 
They reduce +drift between queries and cut down on ad-hoc rewrites: + +```python +# Measures — named aggregates (the API field is `alias`, not `name`) +builder.add_sql_measure( + alias="total_revenue", + sql="SUM(orders.total_amount)", + display_name="Total Revenue", + synonyms=["revenue", "sales"], + comment="Top-line revenue across all orders.", +) +builder.add_sql_measure( + alias="cancellation_rate", + sql="COUNT(CASE WHEN orders.status='Cancelled' THEN 1 END) * 100.0 / COUNT(*)", + display_name="Cancellation Rate", +) + +# Filters — named predicates (no `name` field — uses `display_name`) +builder.add_sql_filter( + sql="orders.status = 'Confirmed'", + display_name="Confirmed orders only", +) + +# Expressions — named SELECT-list fragments +builder.add_sql_expression( + alias="order_year", + sql="YEAR(orders.order_date)", + display_name="Order year", +) +``` + +Keep formulas short, valid, and table-prefixed (`table_name.column_name`). +Genie degrades when snippets reference ambiguous bare columns. The builder +stores `sql` and `comment` as single-element lists — the format the Genie +API uses on the wire. + +### 7. Add text instructions and benchmarks + +Text instructions are free-form markdown the model reads on every query. +Keep them short, actionable, and in business language: + +```python +builder.add_text_instruction( + """ + ## Best Practices + * Always round percentages to two decimal places in summaries. + * When users ask about a KPI without specifying a time range, + ask for the range before answering. + * Use the `total_revenue` measure for any revenue question. + """.strip() +) +``` + +Benchmarks are optional ground-truth Q&A pairs used by the Genie quality +evaluator. 
Include a handful for the questions you care most about: + +```python +builder.add_benchmark( + "What were total sales last month?", + """ + SELECT SUM(total_amount) + FROM my_catalog.sales.orders + WHERE order_date >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL 1 MONTH) + AND order_date < DATE_TRUNC('month', CURRENT_DATE) + """, +) +``` + +## Push the payload + +Export the envelope and pass it to `manage_genie(action="import")`: + +```python +envelope = builder.to_envelope() + +manage_genie( + action="import", + warehouse_id=envelope["warehouse_id"], + serialized_space=envelope["serialized_space"], + title=envelope["title"], + description=envelope["description"], +) +``` + +To update an existing space, use `action="create_or_update"` with the same +`serialized_space` argument — the server will merge top-level overrides +(`display_name`, `description`, `warehouse_id`) on top of the serialized +payload. + +## Round-tripping an existing space + +Fetch an existing space, patch it with the builder, and push it back: + +```python +exported = manage_genie(action="export", space_id="space_123") +builder = GenieSpaceBuilder.from_json(exported) + +builder.add_example_sql( + "What's the average order value by segment?", + "SELECT segment, AVG(total_amount) FROM ... GROUP BY segment", +) +builder.add_sql_measure("avg_order_value", "AVG(orders.total_amount)") + +manage_genie( + action="import", + **builder.to_envelope(), +) +``` + +The builder preserves any fields it does not model, so round-trips are safe +against schema additions on the Genie side. + +## Auto-populating slots from metadata + +The seven-step pipeline maps cleanly to an LLM-driven workflow: scan +metadata, then prompt a model to emit dimensions, measures, joins, +descriptions, sample questions, and snippets. 
See +[`sunnysingh-db/ai-genie-space-generator`](https://github.com/sunnysingh-db/ai-genie-space-generator) +for a reference implementation that pairs a metadata scanner with a +Databricks Foundation Model API pipeline and writes the output through the +same `serialized_space` contract. The builder in this skill is the authoring +layer you would plug that pipeline into. + +## Related skills + +- [databricks-unity-catalog](../databricks-unity-catalog/SKILL.md) — create + the UC metric views that carry the semantic layer +- [databricks-synthetic-data-gen](../databricks-synthetic-data-gen/SKILL.md) — + generate tables to populate a Genie Space for demos +- [databricks-spark-declarative-pipelines](../databricks-spark-declarative-pipelines/SKILL.md) + — build bronze/silver/gold tables consumed by Genie