Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/llama_stack/core/access_control/access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@ def default_policy() -> list[AccessRule]:
return [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]],
when=["user in owners " + name],
)
for name in ["roles", "teams", "projects", "namespaces"]
] + [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user is owner"],
),
AccessRule(
permit=Scope(actions=list(Action)),
when=["resource is unowned"],
),
]

Expand Down
18 changes: 14 additions & 4 deletions src/llama_stack/core/access_control/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def owners_values(self, resource: ProtectedResource) -> list[str] | None:
return None

def matches(self, resource: ProtectedResource, user: User) -> bool:
required = self.owners_values(resource)
if not required:
return True
defined = self.owners_values(resource)
if not defined:
return False
if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]:
return False
user_values = user.attributes[self.name]
for value in required:
for value in defined:
if value in user_values:
return True
return False
Expand Down Expand Up @@ -106,6 +106,14 @@ def __repr__(self):
return "user is not owner"


class ResourceIsUnowned:
def matches(self, resource: ProtectedResource, user: User) -> bool:
return not resource.owner

def __repr__(self):
return "resource is unowned"


def parse_condition(condition: str) -> Condition:
words = condition.split()
match words:
Expand All @@ -121,6 +129,8 @@ def parse_condition(condition: str) -> Condition:
return UserInOwnersList(name)
case ["user", "not", "in", "owners", name]:
return UserNotInOwnersList(name)
case ["resource", "is", "unowned"]:
return ResourceIsUnowned()
case _:
raise ValueError(f"Invalid condition: {condition}")

Expand Down
80 changes: 38 additions & 42 deletions src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# This policy represents: "Permit all actions when user is in owners list for ANY attribute category"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Records with access_attributes require user to match ANY category that exists in the resource
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
# - Between categories, user needs ANY category to match (OR logic)
SQL_OPTIMIZED_POLICY = [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
when=["user in owners " + name],
)
for name in ["roles", "teams", "projects", "namespaces"]
] + [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user is owner"],
),
AccessRule(
permit=Scope(actions=list(Action)),
when=["resource is unowned"],
),
]

Expand Down Expand Up @@ -279,53 +288,40 @@ def _get_public_access_conditions(self) -> list[str]:

Public records are those with:
- owner_principal = '' (empty string)
- access_attributes is either SQL NULL or JSON null

Note: Different databases serialize None differently:
- SQLite: None → JSON null (text = 'null')
- Postgres: None → SQL NULL (IS NULL)
The policy "resource is unowned" only checks if owner_principal is empty,
regardless of access_attributes.
"""
conditions = ["owner_principal = ''"]
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
# Accept both SQL NULL and JSON null for Postgres compatibility
# This handles both old rows (SQL NULL) and new rows (JSON null)
conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
# SQLite serializes None as JSON null
conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
return ["owner_principal = ''"]

def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy.

Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
This means user must match ANY attribute category that exists in the resource (OR logic).
"""
base_conditions = self._get_public_access_conditions()
user_attr_conditions = []

if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

if value_conditions:
# Check if the category is missing (NULL)
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)

if current_user:
# Add "user is owner" condition - user's principal matches owner_principal
escaped_principal = current_user.principal.replace("'", "''")
base_conditions.append(f"owner_principal = '{escaped_principal}'")

# Add "user in owners" conditions for attribute matching
if current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

if value_conditions:
# User matches this category if any of their values match
user_matches_category = f"({' OR '.join(value_conditions)})"
base_conditions.append(user_matches_category)
return f"({' OR '.join(base_conditions)})"

def _build_conservative_where_clause(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
)

# Test a user that has all roles and teams (should generate SQL)
# owner_principal = ''
# owner_principal = 'super-user'
# ((JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"admin"%') OR (JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"user"%'))
# ((JSON_EXTRACT(access_attributes, '$.teams') LIKE '%"dev"%') OR (JSON_EXTRACT(access_attributes, '$.teams') LIKE '%"qa"%'))
super_user = User("super-user", {"roles": ["admin", "user"], "teams": ["dev", "qa"]})
mock_get_authenticated_user.return_value = super_user
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 6

finally:
# Clean up records
await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
Expand Down
17 changes: 9 additions & 8 deletions tests/unit/server/test_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")

mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["user"], "teams": ["other-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public"
Expand Down Expand Up @@ -154,16 +154,16 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
)
await registry.register(model)
mock_get_authenticated_user.return_value = User(
"test-user",
"differentuser",
{
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
with pytest.raises(ValueError):
await routing_table.get_model("model-empty-attrs")
all_models = await routing_table.list_models()
model_ids = [m.identifier for m in all_models.data]
assert "model-empty-attrs" in model_ids
assert "model-empty-attrs" not in model_ids


@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
Expand Down Expand Up @@ -223,7 +223,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
assert registered_model.owner.attributes["projects"] == ["llama-3"]

# Verify another user without matching attributes can't access it
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
mock_get_authenticated_user.return_value = User("test-user2", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model")

Expand Down Expand Up @@ -363,6 +363,7 @@ def test_permit_when():


def test_permit_unless():
# permit unless both conditions are met
config = """
- permit:
principal: user-1
Expand All @@ -377,10 +378,10 @@ def test_permit_unless():
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
owner=User("testuser", {"namespaces": ["foo"], "teams": ["ml-team"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"], "teams": ["ml-team"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))


Expand Down
Loading