Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions databricks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Databricks Asset Bundle definition for mcp_databricks_filtering.
# https://docs.databricks.com/dev-tools/bundles/index.html
bundle:
name: mcp_databricks_filtering

# Resource definitions are read from resources/ and from one directory level below it.
include:
- resources/*.yml
- resources/*/*.yml

# The wheel artifact is produced by running `uv build --wheel`.
artifacts:
python_artifact:
type: whl
build: uv build --wheel

# Bundle variables. workspace_host declares no default in this file;
# catalog and schema default to "main" and "default".
variables:
workspace_host:
description: Databricks workspace URL (e.g. https://adb-xxxx.azuredatabricks.net)
catalog:
description: Default catalog
default: main
schema:
description: Default schema
default: default

targets:
# Development target (marked as the default); schema is overridden with the
# current user's short name.
dev:
mode: development
default: true
workspace:
host: ${var.workspace_host}
variables:
schema: ${workspace.current_user.short_name}

# Production target; schema is pinned to "prod" and the bundle root path sits
# under the current user's /Workspace/Users folder.
prod:
mode: production
workspace:
host: ${var.workspace_host}
root_path: /Workspace/Users/${workspace.current_user.userName}/.bundle/${bundle.name}/${bundle.target}
variables:
schema: prod
9 changes: 9 additions & 0 deletions fixtures/.gitkeep
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test fixtures directory

Add JSON or CSV files here. In tests, use them with `load_fixture()`:

```
def test_using_fixture(load_fixture):
data = load_fixture("my_data.json")
assert len(data) >= 1
```
38 changes: 38 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Package metadata and runtime dependencies.
[project]
name = "mcp_databricks_filtering"
version = "0.1.0"
description = "Tag-based table filtering for the Databricks MCP server"
requires-python = ">=3.10"
dependencies = [
"databricks-sdk>=0.30,<1.0",
"sqlglot>=25.0,<27.0",
]

# Development-only tooling (test runner and linter).
[dependency-groups]
dev = [
"pytest>=8.0",
"ruff>=0.6",
]

# Console script: `mcp-table-filter` invokes main() in mcp_databricks_filtering/main.py.
[project.scripts]
mcp-table-filter = "mcp_databricks_filtering.main:main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# The wheel is built from the package under src/.
[tool.hatch.build.targets.wheel]
packages = ["src/mcp_databricks_filtering"]

[tool.ruff]
line-length = 100
target-version = "py310"

[tool.ruff.lint]
select = ["E", "F", "I", "B", "UP", "SIM"]

[tool.pytest.ini_options]
testpaths = ["tests"]
# Custom marker for tests that need a live Databricks connection.
markers = [
"online: tests that require a live Databricks connection",
]
34 changes: 34 additions & 0 deletions resources/sample_job.job.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Job definition: runs the packaged `mcp-table-filter` entry point on a schedule.
resources:
jobs:
table_filter_refresh:
name: mcp_table_filter_refresh

# Periodic trigger: every 1 day.
trigger:
periodic:
interval: 1
unit: DAYS

# Job parameters; their defaults come from the bundle variables `catalog` and `schema`.
parameters:
- name: catalog
default: ${var.catalog}
- name: schema
default: ${var.schema}

tasks:
# Single wheel task; passes --tag-name mcp-ready --tag-value yes to the entry point.
- task_key: refresh_filter_cache
python_wheel_task:
package_name: mcp_databricks_filtering
entry_point: mcp-table-filter
parameters:
- "--tag-name"
- "mcp-ready"
- "--tag-value"
- "yes"
environment_key: default

# Environment referenced by the task above; installs the wheel(s) found in ../dist.
environments:
- environment_key: default
spec:
environment_version: "4"
dependencies:
- ../dist/*.whl
85 changes: 85 additions & 0 deletions scripts/smoke_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Manual smoke-check of the MCP table filter against a live Databricks workspace.

This is NOT a pytest test — run it directly:

databricks auth login --profile <profile>

MCP_TABLE_FILTER_TAG_NAME=mcp-ready MCP_TABLE_FILTER_TAG_VALUE=yes \\
uv run python scripts/smoke_check.py
"""

from __future__ import annotations

import json
import sys

from mcp_databricks_filtering.config import FilterConfig
from mcp_databricks_filtering.table_filter import TableTagFilter


def banner(label: str) -> None:
    """Print ``label`` framed above and below by a 60-character ``=`` rule.

    A blank line is emitted before the top rule so consecutive sections are
    visually separated.
    """
    rule = "=" * 60
    framed = "\n".join((rule, f" {label}", rule))
    print("\n" + framed)


def main() -> int:
    """Run the manual smoke check and return a process exit code.

    Configuration is read from the environment only (``FilterConfig.from_env``);
    this script parses no command-line arguments.

    Returns:
        1 if the filter is disabled (no tag name configured), 2 if fetching the
        allowlist raised, 0 once all three sections have been printed.
    """
    config = FilterConfig.from_env()
    if not config.is_enabled:
        # This script has no CLI flags, so only point the user at the env var.
        print(
            "Set MCP_TABLE_FILTER_TAG_NAME first "
            "(this script reads its configuration from the environment only).",
            file=sys.stderr,
        )
        return 1

    f = TableTagFilter(config=config)
    print(f"Filter: {config.tag_name}={config.tag_value or '<any>'}")
    print(f"Cache TTL: {config.cache_ttl_seconds}s")

    banner("1. Allowed tables")
    try:
        allowed = f.get_allowed_tables()
    except Exception as exc:
        # Top-level boundary of a manual script: report any failure and stop.
        print(f"ERROR: {exc}", file=sys.stderr)
        return 2

    print(f"\nFound {len(allowed)} allowed table(s):\n")
    for cat, sch, tbl in sorted(allowed):
        print(f" - {cat}.{sch}.{tbl}")

    banner("2. Filter a fake table list")
    # Hand-written table entries run through the filter for catalog "main", schema "eod".
    fake = [
        {"name": "delta_bronze_analystratings"},
        {"name": "delta_bronze_price"},
        {"name": "delta_bronze_dividends"},
        {"name": "some_secret_table"},
    ]
    filtered = f.filter_table_list(fake, "main", "eod")
    print(json.dumps({
        "input": [t["name"] for t in fake],
        "filtered": [t["name"] for t in filtered],
    }, indent=2))

    banner("3. SQL validation cases")
    # (label, statement) pairs; each statement is passed to validate_sql below.
    cases = [
        ("Allowed table", "SELECT * FROM main.eod.delta_bronze_analystratings"),
        ("Blocked table", "SELECT * FROM main.eod.some_secret_table"),
        ("System table", "SELECT * FROM system.information_schema.table_tags"),
        ("USE CATALOG", "USE CATALOG main"),
        ("SHOW TABLES", "SHOW TABLES IN main.eod"),
        ("Join blocked", "SELECT a.* FROM main.eod.delta_bronze_analystratings a "
                         "JOIN main.eod.secret b ON a.id = b.id"),
        ("CTE allowed", "WITH t AS (SELECT * FROM main.eod.delta_bronze_analystratings) "
                        "SELECT * FROM t"),
        ("Garbage SQL", "THIS IS NOT VALID @@@@"),
    ]
    for label, sql in cases:
        # A PermissionError is reported as BLOCKED; any other exception propagates.
        try:
            f.validate_sql(sql)
            print(f" [ALLOWED] {label}")
        except PermissionError as exc:
            print(f" [BLOCKED] {label}: {exc}")

    print("\nSmoke check complete.")
    return 0


if __name__ == "__main__":
sys.exit(main())
17 changes: 17 additions & 0 deletions src/mcp_databricks_filtering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Tag-based table filtering for the Databricks MCP server."""

from mcp_databricks_filtering.config import FilterConfig
from mcp_databricks_filtering.table_filter import (
TableTagFilter,
get_table_filter,
reset_singleton,
)

# Names re-exported from the submodules as this package's public API.
__all__ = [
"FilterConfig",
"TableTagFilter",
"get_table_filter",
"reset_singleton",
]

# NOTE(review): same value as `version` in pyproject.toml — presumably these must be kept in sync.
__version__ = "0.1.0"
76 changes: 76 additions & 0 deletions src/mcp_databricks_filtering/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Configuration for the MCP table filter."""

from __future__ import annotations

import os
import re
from dataclasses import dataclass

# Anchored with \A ... \Z rather than ^ ... $: in Python "$" also matches just
# before a trailing newline, which would let "name\n" through validation.
_TAG_NAME_RE = re.compile(r"\A[A-Za-z_][A-Za-z0-9_-]*\Z")


@dataclass(frozen=True)
class FilterConfig:
    """Immutable configuration for the table filter.

    Use ``FilterConfig.from_env()`` to build from environment variables, or
    construct directly for tests / non-env-driven usage.

    Attributes:
        tag_name: UC tag key to filter on. Empty string disables filtering.
        tag_value: Required tag value (empty = match any value of ``tag_name``).
        cache_ttl_seconds: How long to cache the allowlist (default 300s).
        warehouse_id: Optional explicit SQL warehouse to use for queries.
        fail_closed: If True (default), unparseable SQL is rejected rather than
            allowed through. Disable only in trusted, debug-only contexts.
    """

    tag_name: str = ""
    tag_value: str = ""
    cache_ttl_seconds: int = 300
    warehouse_id: str | None = None
    fail_closed: bool = True

    def __post_init__(self) -> None:
        """Validate field values after construction.

        Raises:
            ValueError: If ``tag_name`` is non-empty and is not entirely made of
                a letter/underscore followed by letters, digits, ``_`` or ``-``,
                or if ``cache_ttl_seconds`` is negative.
        """
        # fullmatch: the whole string must be a valid name, including its last character.
        if self.tag_name and not _TAG_NAME_RE.fullmatch(self.tag_name):
            raise ValueError(
                f"Invalid tag_name {self.tag_name!r}: must match [A-Za-z_][A-Za-z0-9_-]*"
            )
        if self.cache_ttl_seconds < 0:
            raise ValueError(f"cache_ttl_seconds must be >= 0, got {self.cache_ttl_seconds}")

    @property
    def is_enabled(self) -> bool:
        """Whether filtering is active, i.e. a non-empty ``tag_name`` is configured."""
        return bool(self.tag_name)

    @classmethod
    def from_env(cls, env: dict | None = None) -> FilterConfig:
        """Build a FilterConfig from environment variables.

        Recognized env vars (all optional):
            - MCP_TABLE_FILTER_TAG_NAME
            - MCP_TABLE_FILTER_TAG_VALUE
            - MCP_TABLE_FILTER_CACHE_TTL (integer seconds)
            - MCP_TABLE_FILTER_WAREHOUSE_ID
            - MCP_TABLE_FILTER_FAIL_CLOSED (true/false, default true)

        Args:
            env: Mapping to read from instead of ``os.environ`` (useful in tests).

        Returns:
            A validated ``FilterConfig``; string values are whitespace-stripped.

        Raises:
            ValueError: If MCP_TABLE_FILTER_CACHE_TTL is not an integer, or if
                the resulting values fail the checks in ``__post_init__``.
        """
        env = env if env is not None else os.environ

        # A blank TTL value falls back to the 300-second default.
        ttl_raw = env.get("MCP_TABLE_FILTER_CACHE_TTL", "300").strip()
        try:
            cache_ttl = int(ttl_raw) if ttl_raw else 300
        except ValueError as exc:
            raise ValueError(
                f"MCP_TABLE_FILTER_CACHE_TTL must be an integer, got {ttl_raw!r}"
            ) from exc

        # Only an explicit 0/false/no/off (case-insensitive) turns fail-closed off;
        # anything else, including an unrecognized value, keeps it on.
        fail_closed_raw = env.get("MCP_TABLE_FILTER_FAIL_CLOSED", "true").strip().lower()
        fail_closed = fail_closed_raw not in {"0", "false", "no", "off"}

        return cls(
            tag_name=env.get("MCP_TABLE_FILTER_TAG_NAME", "").strip(),
            tag_value=env.get("MCP_TABLE_FILTER_TAG_VALUE", "").strip(),
            cache_ttl_seconds=cache_ttl,
            # A blank warehouse id is treated as "not set".
            warehouse_id=env.get("MCP_TABLE_FILTER_WAREHOUSE_ID", "").strip() or None,
            fail_closed=fail_closed,
        )
60 changes: 60 additions & 0 deletions src/mcp_databricks_filtering/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""CLI entry point — prints the current allowlist."""

from __future__ import annotations

import argparse
import sys
from dataclasses import replace

from mcp_databricks_filtering.config import FilterConfig
from mcp_databricks_filtering.table_filter import TableTagFilter


def main(argv: list[str] | None = None) -> int:
    """Print the current table allowlist and return a process exit code.

    The configuration comes from the environment; ``--tag-name`` and
    ``--tag-value`` override the corresponding fields when given.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        1 if the filter is disabled, 2 if querying the allowlist raised,
        0 after the allowlist has been printed.
    """
    cli = argparse.ArgumentParser(description="MCP Databricks table filter utility")
    cli.add_argument(
        "--tag-name",
        default=None,
        help="UC tag key to filter on (overrides MCP_TABLE_FILTER_TAG_NAME)",
    )
    cli.add_argument(
        "--tag-value",
        default=None,
        help="Required tag value (overrides MCP_TABLE_FILTER_TAG_VALUE)",
    )
    opts = cli.parse_args(argv)

    config = FilterConfig.from_env()

    # Collect only the flags that were actually passed; untouched fields keep
    # their environment-derived values.
    overrides = {}
    if opts.tag_name is not None:
        overrides["tag_name"] = opts.tag_name
    if opts.tag_value is not None:
        overrides["tag_value"] = opts.tag_value
    if overrides:
        config = replace(config, **overrides)

    if not config.is_enabled:
        print(
            "Table filter is DISABLED (set MCP_TABLE_FILTER_TAG_NAME or use --tag-name)",
            file=sys.stderr,
        )
        return 1

    table_filter = TableTagFilter(config=config)
    shown_value = config.tag_value or "<any>"
    print(f"Filter: {config.tag_name}={shown_value}")
    print(f"Cache TTL: {config.cache_ttl_seconds}s")
    print()

    try:
        allowed = table_filter.get_allowed_tables()
    except Exception as exc:
        # Top-level CLI boundary: report the failure instead of a traceback.
        print(f"ERROR querying allowed tables: {exc}", file=sys.stderr)
        return 2

    print(f"Found {len(allowed)} allowed table(s):")
    for catalog, schema, table in sorted(allowed):
        print(f" {catalog}.{schema}.{table}")
    return 0


if __name__ == "__main__":
sys.exit(main())
Loading