diff --git a/README.md b/README.md index 0dd94d8..85f8620 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,10 @@ # toolhive-core +[![Release][release-img]][release] [![Build Status][ci-img]][ci] +[![Coverage Status][coveralls-img]][coveralls] +[![License: Apache 2.0][license-img]][license] +[![Star on GitHub][stars-img]][stars] + The ToolHive Platform common libraries and specifications. `toolhive-core` provides stable, well-tested Go utilities with explicit API guarantees for the ToolHive ecosystem. Projects like [toolhive](https://github.com/stacklok/toolhive), [dockyard](https://github.com/stacklok/dockyard), [toolhive-registry](https://github.com/stacklok/toolhive-registry), and [toolhive-registry-server](https://github.com/stacklok/toolhive-registry-server) depend on this library for shared functionality. @@ -99,3 +104,19 @@ For packages with external dependencies, multiple types, or broader API surface: ## License Apache-2.0 - See [LICENSE](LICENSE) for details. + + + +[release]: https://github.com/stacklok/toolhive-core/releases/latest +[release-img]: https://img.shields.io/github/v/release/stacklok/toolhive-core +[ci]: https://github.com/stacklok/toolhive-core/actions/workflows/ci.yml +[ci-img]: https://github.com/stacklok/toolhive-core/actions/workflows/ci.yml/badge.svg +[coveralls]: https://coveralls.io/github/stacklok/toolhive-core +[coveralls-img]: https://coveralls.io/repos/github/stacklok/toolhive-core/badge.svg +[license]: https://opensource.org/licenses/Apache-2.0 +[license-img]: https://img.shields.io/badge/License-Apache%202.0-blue.svg +[stars]: https://github.com/stacklok/toolhive-core/stargazers +[stars-img]: https://img.shields.io/github/stars/stacklok/toolhive-core?style=social + + + diff --git a/cel/engine_test.go b/cel/engine_test.go index f88fff4..2cb83ff 100644 --- a/cel/engine_test.go +++ b/cel/engine_test.go @@ -414,6 +414,79 @@ func TestCheckError_Details(t *testing.T) { assert.NotEmpty(t, checkErr.Errors) } +func TestEngine_WithMaxExpressionLength(t *testing.T) { + t.Parallel() + + t.Run("rejects expression exceeding custom limit", func(t *testing.T) { + t.Parallel() + + engine := newTestClaimsEngine().WithMaxExpressionLength(10) + + _, err := engine.Compile(`claims["sub"] == "user123"`) + require.Error(t, err) + assert.ErrorIs(t, err, cel.ErrExpressionCheck) + }) + + t.Run("accepts expression within custom limit", func(t *testing.T) { + t.Parallel() + + engine := newTestClaimsEngine().WithMaxExpressionLength(100) + + expr, err := engine.Compile(`true`) + require.NoError(t, err) + require.NotNil(t, expr) + }) + + t.Run("rejects expression at default limit via Check", func(t *testing.T) { + t.Parallel() + + engine := newTestClaimsEngine().WithMaxExpressionLength(5) + + err := engine.Check(`claims["sub"] == "user123"`) + require.Error(t, err) + assert.ErrorIs(t, err, cel.ErrExpressionCheck) + }) +} + +func TestEngine_WithCostLimit(t *testing.T) { + t.Parallel() + + t.Run("returns engine for chaining", func(t *testing.T) { + t.Parallel() + + engine := newTestClaimsEngine().WithCostLimit(500000) + require.NotNil(t, engine) + }) + + t.Run("compiles and evaluates within cost limit", func(t *testing.T) { + t.Parallel() + + engine := newTestClaimsEngine().WithCostLimit(cel.DefaultCostLimit) + + expr, err := engine.Compile(`claims["sub"] == "user123"`) + require.NoError(t, err) + + ctx := map[string]any{"claims": map[string]any{"sub": "user123"}} + result, err := expr.EvaluateBool(ctx) + require.NoError(t, err) + assert.True(t, result) + }) + + t.Run("zero cost limit rejects evaluation", func(t *testing.T) { + t.Parallel() + + engine := newTestClaimsEngine().WithCostLimit(0) + + expr, err := engine.Compile(`claims["sub"] == "user123"`) + require.NoError(t, err) + + ctx := map[string]any{"claims": map[string]any{"sub": "user123"}} + _, err = expr.EvaluateBool(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, cel.ErrEvaluation) + }) +} + func TestEngine_Concurrency(t *testing.T) { t.Parallel() diff --git a/container/verifier/verifier_test.go b/container/verifier/verifier_test.go new file mode 100644 index 0000000..464fa52 --- /dev/null +++ b/container/verifier/verifier_test.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package verifier + +import ( + "testing" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + registry "github.com/stacklok/toolhive-core/registry/types" +) + +// ------ New ------ + +func TestNew_NilProvenance(t *testing.T) { + t.Parallel() + _, err := New(nil, nil) + assert.ErrorIs(t, err, ErrProvenanceServerInformationNotSet) +} + +// ------ WithKeychain ------ + +func TestWithKeychain_SetsKeychain(t *testing.T) { + t.Parallel() + s := &Sigstore{} + kc := authn.NewMultiKeychain() + got := s.WithKeychain(kc) + assert.Same(t, s, got, "WithKeychain should return the same *Sigstore") + assert.Equal(t, kc, s.keychain) +} + +// ------ GetVerificationResults ------ + +// GetVerificationResults with an unparseable image reference should return an error +// (not ErrProvenanceNotFoundOrIncomplete, so the error propagates directly). +func TestGetVerificationResults_InvalidImageRef(t *testing.T) { + t.Parallel() + s := &Sigstore{keychain: authn.DefaultKeychain} + results, err := s.GetVerificationResults("") + assert.Error(t, err) + assert.Nil(t, results) +} + +// ------ VerifyServer ------ + +// VerifyServer propagates errors from GetVerificationResults. +func TestVerifyServer_PropagatesGetVerificationError(t *testing.T) { + t.Parallel() + s := &Sigstore{keychain: authn.DefaultKeychain} + err := s.VerifyServer("", ®istry.Provenance{}) + assert.Error(t, err) + assert.NotErrorIs(t, err, ErrImageNotSigned) + assert.NotErrorIs(t, err, ErrProvenanceMismatch) +} + +// VerifyServer with a nil provenance still calls GetVerificationResults first; +// if that errors the provenance nil-ness is irrelevant. +func TestVerifyServer_NilProvenance_InvalidRef(t *testing.T) { + t.Parallel() + s := &Sigstore{keychain: authn.DefaultKeychain} + err := s.VerifyServer("", nil) + require.Error(t, err) +} diff --git a/registry/types/registry_types_test.go b/registry/types/registry_types_test.go new file mode 100644 index 0000000..5d10bcc --- /dev/null +++ b/registry/types/registry_types_test.go @@ -0,0 +1,294 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// registryYAML is a realistic registry fixture that exercises the custom +// UnmarshalYAML logic (BaseServerMetadata fields are at the top level, not +// nested) as well as all registry accessor methods. +const registryYAML = ` +version: "1.0.0" +last_updated: "2024-06-01T00:00:00Z" +servers: + server-b: + name: server-b + description: Second container server + tier: Community + status: Active + transport: stdio + image: example/server-b:latest + tools: + - tool_b + metadata: + stars: 5 + last_updated: "2024-01-02T00:00:00Z" + server-a: + name: server-a + description: First container server + tier: Official + status: Active + transport: sse + image: example/server-a:latest + target_port: 8080 + tools: + - tool_a1 + - tool_a2 + tags: + - ai + env_vars: + - name: API_KEY + description: API key + required: true + secret: true +remote_servers: + remote-a: + name: remote-a + description: A remote server + tier: Community + status: Active + transport: streamable-http + url: https://api.example.com/mcp + proxy_port: 9090 + headers: + - name: X-API-Key + description: API key header + required: true + secret: true +groups: + - name: ai-group + description: AI tools group + servers: + server-a: + name: server-a + description: First container server + tier: Official + status: Active + transport: sse + image: example/server-a:latest + remote_servers: + remote-a: + name: remote-a + description: A remote server + tier: Community + status: Active + transport: streamable-http + url: https://api.example.com/mcp + - name: empty-group + description: No servers yet +` + +func parseTestRegistry(t *testing.T) *Registry { + t.Helper() + var reg Registry + require.NoError(t, yaml.Unmarshal([]byte(registryYAML), ®)) + return ® +} + +// TestRegistry_YAMLRoundTrip verifies that the custom UnmarshalYAML correctly +// hydrates both ImageMetadata and RemoteServerMetadata (including embedded +// BaseServerMetadata fields) from a flat YAML document. +func TestRegistry_YAMLRoundTrip(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + // Container server – base fields promoted via UnmarshalYAML + sa := reg.Servers["server-a"] + require.NotNil(t, sa) + assert.Equal(t, "server-a", sa.Name) + assert.Equal(t, "Official", sa.Tier) + assert.Equal(t, "sse", sa.Transport) + assert.Equal(t, "example/server-a:latest", sa.Image) + assert.Equal(t, 8080, sa.TargetPort) + assert.Equal(t, []string{"tool_a1", "tool_a2"}, sa.Tools) + assert.Equal(t, []string{"ai"}, sa.Tags) + require.Len(t, sa.EnvVars, 1) + assert.Equal(t, "API_KEY", sa.EnvVars[0].Name) + assert.True(t, sa.EnvVars[0].Required) + assert.True(t, sa.EnvVars[0].Secret) + + // Remote server – base fields promoted via UnmarshalYAML + ra := reg.RemoteServers["remote-a"] + require.NotNil(t, ra) + assert.Equal(t, "remote-a", ra.Name) + assert.Equal(t, "streamable-http", ra.Transport) + assert.Equal(t, "https://api.example.com/mcp", ra.URL) + assert.Equal(t, 9090, ra.ProxyPort) + require.Len(t, ra.Headers, 1) + assert.Equal(t, "X-API-Key", ra.Headers[0].Name) + assert.True(t, ra.Headers[0].Secret) +} + +// TestRegistry_GetAllServers exercises the unified server listing through +// GetAllServers and confirms IsRemote distinguishes the two kinds. +func TestRegistry_GetAllServers(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + all := reg.GetAllServers() + assert.Len(t, all, 3) // server-a, server-b, remote-a + + var remotes, containers int + for _, s := range all { + if s.IsRemote() { + remotes++ + } else { + containers++ + } + } + assert.Equal(t, 1, remotes) + assert.Equal(t, 2, containers) +} + +// TestRegistry_GetServerByName exercises server lookup and verifies that +// metadata returned through the ServerMetadata interface is correct. +func TestRegistry_GetServerByName(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + tests := []struct { + name string + wantName string + remote bool + found bool + }{ + {"server-a", "server-a", false, true}, + {"remote-a", "remote-a", true, true}, + {"missing", "", false, false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + srv, ok := reg.GetServerByName(tc.name) + assert.Equal(t, tc.found, ok) + if tc.found { + assert.Equal(t, tc.wantName, srv.GetName()) + assert.Equal(t, tc.remote, srv.IsRemote()) + } + }) + } +} + +// TestRegistry_GetAllServers_SortedByName exercises SortServersByName on the +// full server list, ensuring deterministic ordering. +func TestRegistry_GetAllServers_SortedByName(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + all := reg.GetAllServers() + SortServersByName(all) + + names := make([]string, len(all)) + for i, s := range all { + names[i] = s.GetName() + } + assert.Equal(t, []string{"remote-a", "server-a", "server-b"}, names) +} + +// TestRegistry_Groups exercises group lookup and group server enumeration. +func TestRegistry_Groups(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + assert.Len(t, reg.GetAllGroups(), 2) + + g, ok := reg.GetGroupByName("ai-group") + require.True(t, ok) + assert.Equal(t, "AI tools group", g.Description) + + servers := g.GetAllGroupServers() + assert.Len(t, servers, 2) + + _, ok = reg.GetGroupByName("nonexistent") + assert.False(t, ok) +} + +// TestRegistry_EmptyGroup verifies GetAllGroupServers on a group with no servers. +func TestRegistry_EmptyGroup(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + g, ok := reg.GetGroupByName("empty-group") + require.True(t, ok) + assert.Empty(t, g.GetAllGroupServers()) +} + +// TestRegistry_ServerMetadataInterface verifies that iterating all servers via +// the ServerMetadata interface returns sensible values – the way display/filter +// code in consuming packages calls these methods. +func TestRegistry_ServerMetadataInterface(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + all := reg.GetAllServers() + SortServersByName(all) + + // Spot-check server-a through the interface (covers all BaseServerMetadata getters) + var sa ServerMetadata + for _, s := range all { + if s.GetName() == "server-a" { + sa = s + break + } + } + require.NotNil(t, sa) + assert.Equal(t, "First container server", sa.GetDescription()) + assert.Equal(t, "Official", sa.GetTier()) + assert.Equal(t, "Active", sa.GetStatus()) + assert.Equal(t, "sse", sa.GetTransport()) + assert.Equal(t, []string{"tool_a1", "tool_a2"}, sa.GetTools()) + assert.Equal(t, []string{"ai"}, sa.GetTags()) + assert.Equal(t, "", sa.GetTitle()) // not set in fixture + assert.Equal(t, "", sa.GetOverview()) // not set in fixture + assert.Equal(t, "", sa.GetRepositoryURL()) + assert.Nil(t, sa.GetToolDefinitions()) + assert.Nil(t, sa.GetCustomMetadata()) + assert.Nil(t, sa.GetMetadata()) + assert.False(t, sa.IsRemote()) + require.Len(t, sa.GetEnvVars(), 1) + assert.Equal(t, "API_KEY", sa.GetEnvVars()[0].Name) + + // Spot-check remote-a through the interface + var ra ServerMetadata + for _, s := range all { + if s.GetName() == "remote-a" { + ra = s + break + } + } + require.NotNil(t, ra) + assert.True(t, ra.IsRemote()) + assert.Equal(t, "A remote server", ra.GetDescription()) + assert.Empty(t, ra.GetEnvVars()) + + // GetRawImplementation on the concrete type + concrete, ok := reg.RemoteServers["remote-a"] + require.True(t, ok) + assert.Equal(t, concrete, concrete.GetRawImplementation()) +} + +// TestMetadata_ParsedTime exercises ParsedTime through a real server's metadata. +func TestMetadata_ParsedTime(t *testing.T) { + t.Parallel() + reg := parseTestRegistry(t) + + srv := reg.Servers["server-b"] + require.NotNil(t, srv.Metadata) + + ts, err := srv.Metadata.ParsedTime() + require.NoError(t, err) + assert.Equal(t, 2024, ts.Year()) + assert.Equal(t, 2, ts.Day()) + + // Invalid timestamp produces an error + m := &Metadata{LastUpdated: "not-a-date"} + _, err = m.ParsedTime() + assert.Error(t, err) +} diff --git a/registry/types/schema_validation_test.go b/registry/types/schema_validation_test.go index 7269237..3580a48 100644 --- a/registry/types/schema_validation_test.go +++ b/registry/types/schema_validation_test.go @@ -9,6 +9,7 @@ import ( "regexp" "testing" + upstreamv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" "github.com/modelcontextprotocol/registry/pkg/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1504,6 +1505,171 @@ func TestValidateSkillBytes(t *testing.T) { } } +// TestRegistry_Validate tests the Validate method on the Registry struct. +func TestRegistry_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + registry *Registry + expectError bool + }{ + { + name: "valid minimal registry", + registry: &Registry{ + Version: "1.0.0", + LastUpdated: "2025-01-01T00:00:00Z", + Servers: map[string]*ImageMetadata{}, + }, + expectError: false, + }, + { + name: "invalid registry - empty struct", + registry: &Registry{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.registry.Validate() + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestUpstreamRegistry_Validate tests the Validate method on the UpstreamRegistry struct. +func TestUpstreamRegistry_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + registry *UpstreamRegistry + expectError bool + }{ + { + name: "valid minimal upstream registry", + registry: &UpstreamRegistry{ + Schema: UpstreamRegistrySchemaURL, + Version: "1.0.0", + Meta: UpstreamMeta{ + LastUpdated: "2025-01-01T00:00:00Z", + }, + Data: UpstreamData{ + Servers: []upstreamv0.ServerJSON{}, + }, + }, + expectError: false, + }, + { + name: "invalid upstream registry - empty struct", + registry: &UpstreamRegistry{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.registry.Validate() + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestServerExtensions_Validate tests the Validate method on the ServerExtensions struct. +func TestServerExtensions_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + extensions *ServerExtensions + expectError bool + }{ + { + name: "invalid extensions - empty struct fails schema structure", + extensions: &ServerExtensions{}, + expectError: true, + }, + { + name: "invalid extensions - non-empty struct fails schema structure", + extensions: &ServerExtensions{Status: "active", Tier: "Official"}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.extensions.Validate() + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestSkill_Validate tests the Validate method on the Skill struct. +func TestSkill_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + skill *Skill + expectError bool + }{ + { + name: "valid skill", + skill: &Skill{ + Namespace: "io.github.example", + Name: "my-skill", + Description: "A test skill for validation", + Version: "1.0.0", + }, + expectError: false, + }, + { + name: "invalid skill - empty struct", + skill: &Skill{}, + expectError: true, + }, + { + name: "invalid skill - invalid status enum", + skill: &Skill{ + Namespace: "io.github.example", + Name: "my-skill", + Description: "A test skill for validation", + Version: "1.0.0", + Status: "invalid-status", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.skill.Validate() + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + // TestUpstreamRegistrySchemaVersionSync ensures that the schema reference in // upstream-registry.schema.json matches the schema version from the Go package // (model.CurrentSchemaVersion). This prevents schema drift when upgrading the