diff --git a/src/llama_stack/core/access_control/access_control.py b/src/llama_stack/core/access_control/access_control.py index 9449132720..b69674a525 100644 --- a/src/llama_stack/core/access_control/access_control.py +++ b/src/llama_stack/core/access_control/access_control.py @@ -66,7 +66,17 @@ def default_policy() -> list[AccessRule]: return [ AccessRule( permit=Scope(actions=list(Action)), - when=["user in owners " + name for name in ["roles", "teams", "projects", "namespaces"]], + when=["user in owners " + name], + ) + for name in ["roles", "teams", "projects", "namespaces"] + ] + [ + AccessRule( + permit=Scope(actions=list(Action)), + when=["user is owner"], + ), + AccessRule( + permit=Scope(actions=list(Action)), + when=["resource is unowned"], ), ] diff --git a/src/llama_stack/core/access_control/conditions.py b/src/llama_stack/core/access_control/conditions.py index f6f2eb7424..4aa346eaa9 100644 --- a/src/llama_stack/core/access_control/conditions.py +++ b/src/llama_stack/core/access_control/conditions.py @@ -38,13 +38,13 @@ def owners_values(self, resource: ProtectedResource) -> list[str] | None: return None def matches(self, resource: ProtectedResource, user: User) -> bool: - required = self.owners_values(resource) - if not required: - return True + defined = self.owners_values(resource) + if not defined: + return False if not user.attributes or self.name not in user.attributes or not user.attributes[self.name]: return False user_values = user.attributes[self.name] - for value in required: + for value in defined: if value in user_values: return True return False @@ -106,6 +106,14 @@ def __repr__(self): return "user is not owner" +class ResourceIsUnowned: + def matches(self, resource: ProtectedResource, user: User) -> bool: + return not resource.owner + + def __repr__(self): + return "resource is unowned" + + def parse_condition(condition: str) -> Condition: words = condition.split() match words: @@ -121,6 +129,8 @@ def parse_condition(condition: str) -> Condition: return UserInOwnersList(name) case ["user", "not", "in", "owners", name]: return UserNotInOwnersList(name) + case ["resource", "is", "unowned"]: + return ResourceIsUnowned() case _: raise ValueError(f"Invalid condition: {condition}") diff --git a/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py b/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py index 9cd9a8a29b..792a34e821 100644 --- a/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py +++ b/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py @@ -23,17 +23,26 @@ # WARNING: If default_policy() changes, this constant must be updated accordingly # or SQL filtering will fall back to conservative mode (safe but less performant) # -# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories" +# This policy represents: "Permit all actions when user is in owners list for ANY attribute category" # The corresponding SQL logic is implemented in _build_default_policy_where_clause(): # - Public records (no access_attributes) are always accessible -# - Records with access_attributes require user to match ALL categories that exist in the resource -# - Missing categories in the resource are treated as "no restriction" (allow) +# - Records with access_attributes require user to match ANY category that exists in the resource # - Within each category, user needs ANY matching value (OR logic) -# - Between categories, user needs ALL categories to match (AND logic) +# - Between categories, user needs ANY category to match (OR logic) SQL_OPTIMIZED_POLICY = [ AccessRule( permit=Scope(actions=list(Action)), - when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"], + when=["user in owners " + name], + ) + for name in ["roles", "teams", "projects", "namespaces"] +] + [ + AccessRule( + permit=Scope(actions=list(Action)), + when=["user is owner"], + ), + AccessRule( + permit=Scope(actions=list(Action)), + when=["resource is unowned"], ), ] @@ -279,53 +288,40 @@ def _get_public_access_conditions(self) -> list[str]: Public records are those with: - owner_principal = '' (empty string) - - access_attributes is either SQL NULL or JSON null - Note: Different databases serialize None differently: - - SQLite: None → JSON null (text = 'null') - - Postgres: None → SQL NULL (IS NULL) + The policy "resource is unowned" only checks if owner_principal is empty, + regardless of access_attributes. """ - conditions = ["owner_principal = ''"] - if self.database_type == StorageBackendType.SQL_POSTGRES.value: - # Accept both SQL NULL and JSON null for Postgres compatibility - # This handles both old rows (SQL NULL) and new rows (JSON null) - conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')") - elif self.database_type == StorageBackendType.SQL_SQLITE.value: - # SQLite serializes None as JSON null - conditions.append("(access_attributes IS NULL OR access_attributes = 'null')") - else: - raise ValueError(f"Unsupported database type: {self.database_type}") - return conditions + return ["owner_principal = ''"] def _build_default_policy_where_clause(self, current_user: User | None) -> str: """Build SQL WHERE clause for the default policy. Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] - This means user must match ALL attribute categories that exist in the resource. + This means user must match ANY attribute category that exists in the resource (OR logic). """ base_conditions = self._get_public_access_conditions() - user_attr_conditions = [] - - if current_user and current_user.attributes: - for attr_key, user_values in current_user.attributes.items(): - if user_values: - value_conditions = [] - for value in user_values: - # Check if JSON array contains the value - escaped_value = value.replace("'", "''") - json_text = self._json_extract_text("access_attributes", attr_key) - value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')") - - if value_conditions: - # Check if the category is missing (NULL) - category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL" - user_matches_category = f"({' OR '.join(value_conditions)})" - user_attr_conditions.append(f"({category_missing} OR {user_matches_category})") - - if user_attr_conditions: - all_requirements_met = f"({' AND '.join(user_attr_conditions)})" - base_conditions.append(all_requirements_met) + if current_user: + # Add "user is owner" condition - user's principal matches owner_principal + escaped_principal = current_user.principal.replace("'", "''") + base_conditions.append(f"owner_principal = '{escaped_principal}'") + + # Add "user in owners" conditions for attribute matching + if current_user.attributes: + for attr_key, user_values in current_user.attributes.items(): + if user_values: + value_conditions = [] + for value in user_values: + # Check if JSON array contains the value + escaped_value = value.replace("'", "''") + json_text = self._json_extract_text("access_attributes", attr_key) + value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')") + + if value_conditions: + # User matches this category if any of their values match + user_matches_category = f"({' OR '.join(value_conditions)})" + base_conditions.append(user_matches_category) return f"({' OR '.join(base_conditions)})" def _build_conservative_where_clause(self) -> str: diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index cfa2e35381..31bd3d9022 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -184,6 +184,16 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz f"Category missing logic failed: expected 4,5 but got {category_test_ids}" ) + # Test a user that has all roles and teams (should generate SQL) + # owner_principal = '' + # owner_principal = 'super-user' + # ((JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"admin"%') OR (JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"user"%')) + # ((JSON_EXTRACT(access_attributes, '$.teams') LIKE '%"dev"%') OR (JSON_EXTRACT(access_attributes, '$.teams') LIKE '%"qa"%')) + super_user = User("super-user", {"roles": ["admin", "user"], "teams": ["dev", "qa"]}) + mock_get_authenticated_user.return_value = super_user + result = await authorized_store.fetch_all(table_name) + assert len(result.data) == 6 + finally: # Clean up records await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"]) diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index bf6a24c906..44c58a59fd 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -78,7 +78,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup with pytest.raises(ValueError): await routing_table.get_model("model-data-scientist") - mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]}) + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["user"], "teams": ["other-team"]}) all_models = await routing_table.list_models() assert len(all_models.data) == 1 assert all_models.data[0].identifier == "model-public" @@ -154,16 +154,16 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test ) await registry.register(model) mock_get_authenticated_user.return_value = User( - "test-user", + "differentuser", { "roles": [], }, ) - result = await routing_table.get_model("model-empty-attrs") - assert result.identifier == "model-empty-attrs" + with pytest.raises(ValueError): + await routing_table.get_model("model-empty-attrs") all_models = await routing_table.list_models() model_ids = [m.identifier for m in all_models.data] - assert "model-empty-attrs" in model_ids + assert "model-empty-attrs" not in model_ids @patch("llama_stack.core.routing_tables.common.get_authenticated_user") @@ -223,7 +223,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set assert registered_model.owner.attributes["projects"] == ["llama-3"] # Verify another user without matching attributes can't access it - mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]}) + mock_get_authenticated_user.return_value = User("test-user2", {"roles": ["engineer"], "teams": ["infra-team"]}) with pytest.raises(ValueError): await routing_table.get_model("auto-access-model") @@ -363,6 +363,7 @@ def test_permit_when(): def test_permit_unless(): + # permit unless both conditions are met config = """ - permit: principal: user-1 @@ -377,10 +378,10 @@ def test_permit_unless(): identifier="mymodel", provider_id="myprovider", model_type=ModelType.llm, - owner=User("testuser", {"namespaces": ["foo"]}), + owner=User("testuser", {"namespaces": ["foo"], "teams": ["ml-team"]}), ) assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]})) - assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]})) + assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"], "teams": ["ml-team"]})) assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))