54 changes: 54 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/README.md
@@ -328,6 +328,10 @@ The `type` corresponding to NVIDIA Triton Inference Server is `triton`. Its unique configuration fields are:
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |

#### Cerebras

For Cerebras, the corresponding `type` is `cerebras`. It has no unique configuration fields.

## Usage Examples

### Using OpenAI Protocol Proxy for Azure OpenAI Service
@@ -2100,11 +2104,61 @@ providers:
}
```

### Using OpenAI Protocol Proxy for Cerebras Service

**Configuration Information**

```yaml
provider:
  type: cerebras
  apiTokens:
    - "YOUR_CEREBRAS_API_TOKEN"
  modelMapping:
    "gpt-4": "llama3.1-70b"
    "gpt-3.5-turbo": "llama3.1-8b"
    "*": "llama3.1-8b"
```

**Request Example**

```json
{
  "model": "gpt-4",
  "messages": [
    {
      "role": "user",
      "content": "Hello, who are you?"
    }
  ],
  "stream": false
}
```

**Response Example**

```json
{
  "id": "cmpl-123456789",
  "object": "chat.completion",
  "created": 1699123456,
  "model": "llama3.1-70b",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "Hello! I am an AI assistant provided by Cerebras, based on the Llama 3.1 model. I can help answer questions, hold conversations, and provide all kinds of information. How can I help you?"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 50,
    "total_tokens": 60
  }
}
```

## Complete Configuration Example

75 changes: 74 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/README_EN.md
@@ -26,10 +26,10 @@ The plugin now supports **automatic protocol detection**, allowing seamless compatibility
> When the request path suffix matches `/v1/embeddings`, it corresponds to text vector scenarios. The request body will be parsed using OpenAI's text vector protocol and then converted to the corresponding LLM vendor's text vector protocol.

## Execution Properties

Plugin execution phase: `Default Phase`
Plugin execution priority: `100`


## Configuration Fields

### Basic Configuration
@@ -243,6 +243,7 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field is:
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |

#### Google Vertex AI

For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:

| Name | Data Type | Requirement | Default | Description |
@@ -265,6 +266,10 @@ For AWS Bedrock, the corresponding `type` is `bedrock`. Its unique configuration
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |

#### Cerebras

For Cerebras, the corresponding `type` is `cerebras`. It has no unique configuration fields.

## Usage Examples

### Using OpenAI Protocol Proxy for Azure OpenAI Service
@@ -1657,6 +1662,7 @@ Here, `model` denotes the service tier of DeepL and can only be either `Free` or `Pro`.
### Utilizing OpenAI Protocol Proxy for Together-AI Services

**Configuration Information**

```yaml
provider:
  type: together-ai
@@ -1667,6 +1673,7 @@ provider:
```

**Request Example**

```json
{
  "model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
@@ -1680,6 +1687,7 @@ provider:
```

**Response Example**

```json
{
  "id": "8f5809d54b73efac",
@@ -1709,7 +1717,9 @@ provider:
```

### Utilizing OpenAI Protocol Proxy for Google Vertex Services

**Configuration Information**

```yaml
provider:
  type: vertex
@@ -1728,6 +1738,7 @@ provider:
```

**Request Example**

```json
{
  "model": "gemini-2.0-flash-001",
@@ -1742,6 +1753,7 @@ provider:
```

**Response Example**

```json
{
  "id": "chatcmpl-0000000000000",
@@ -1767,7 +1779,9 @@ provider:
```

### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services

**Configuration Information**

```yaml
provider:
  type: bedrock
@@ -1779,6 +1793,7 @@ provider:
```

**Request Example**

```json
{
  "model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
@@ -1793,6 +1808,7 @@ provider:
```

**Response Example**

```json
{
  "id": "d52da49d-daf3-49d9-a105-0b527481fe14",
@@ -1817,6 +1833,62 @@ provider:
}
```

### Utilizing OpenAI Protocol Proxy for Cerebras Services

**Configuration Information**

```yaml
provider:
  type: cerebras
  apiTokens:
    - "YOUR_CEREBRAS_API_TOKEN"
  modelMapping:
    "gpt-4": "llama3.1-70b"
    "gpt-3.5-turbo": "llama3.1-8b"
    "*": "llama3.1-8b"
```
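The `modelMapping` table above rewrites the model name carried in the OpenAI-style request before it is forwarded to Cerebras: an exact key wins, and `"*"` acts as the catch-all. The sketch below illustrates that lookup order; the `mapModel` helper and the `"prefix*"` rule handling are our illustration under assumed semantics, not the plugin's actual implementation.

```go
package main

import (
	"fmt"
	"strings"
)

// mapModel resolves a requested model name against a modelMapping table.
// Assumed lookup order: exact match, then a "prefix*" rule, then the "*"
// catch-all; if nothing matches, the name passes through unchanged.
func mapModel(mapping map[string]string, model string) string {
	if target, ok := mapping[model]; ok {
		return target
	}
	for pattern, target := range mapping {
		if pattern != "*" && strings.HasSuffix(pattern, "*") &&
			strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
			return target
		}
	}
	if target, ok := mapping["*"]; ok {
		return target
	}
	return model
}

func main() {
	mapping := map[string]string{
		"gpt-4":         "llama3.1-70b",
		"gpt-3.5-turbo": "llama3.1-8b",
		"*":             "llama3.1-8b",
	}
	fmt.Println(mapModel(mapping, "gpt-4"))      // llama3.1-70b
	fmt.Println(mapModel(mapping, "some-model")) // llama3.1-8b
}
```

With the configuration above, any model name other than `gpt-4` or `gpt-3.5-turbo` falls through to the `"*"` entry and is forwarded as `llama3.1-8b`.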

**Request Example**

```json
{
  "model": "gpt-4",
  "messages": [
    {
      "role": "user",
      "content": "who are you"
    }
  ],
  "stream": false
}
```

**Response Example**

```json
{
  "id": "cmpl-123456789",
  "object": "chat.completion",
  "created": 1699123456,
  "model": "llama3.1-70b",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "Hello! I am an AI assistant powered by Cerebras, based on the Llama 3.1 model. I can help answer questions, engage in conversations, and provide various information. How can I assist you today?"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 50,
    "total_tokens": 60
  }
}
```

### Utilizing OpenAI Protocol Proxy for NVIDIA Triton Inference Server Services

**Configuration Information**
@@ -1846,6 +1918,7 @@ providers:
"stream": false
}
```

**Response Example**

```json
9 changes: 9 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/main_test.go
@@ -109,3 +109,12 @@ func TestFireworks(t *testing.T) {
	test.RunFireworksOnHttpRequestHeadersTests(t)
	test.RunFireworksOnHttpRequestBodyTests(t)
}

func TestCerebras(t *testing.T) {
	test.RunCerebrasParseConfigTests(t)
	test.RunCerebrasOnHttpRequestHeadersTests(t)
	test.RunCerebrasOnHttpRequestBodyTests(t)
	test.RunCerebrasOnHttpResponseHeadersTests(t)
	test.RunCerebrasOnHttpResponseBodyTests(t)
	test.RunCerebrasOnStreamingResponseBodyTests(t)
}
119 changes: 119 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/cerebras.go
@@ -0,0 +1,119 @@
package provider

import (
	"errors"
	"net/http"
	"path"
	"strings"

	"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/higress-group/wasm-go/pkg/log"
	"github.com/higress-group/wasm-go/pkg/wrapper"
)

// cerebrasProvider is the provider for the Cerebras service.

const (
	defaultCerebrasDomain = "api.cerebras.ai"
)

type cerebrasProviderInitializer struct{}

func (c *cerebrasProviderInitializer) ValidateConfig(config *ProviderConfig) error {
	if len(config.apiTokens) == 0 {
		return errors.New("no apiToken found in provider config")
	}
	return nil
}

func (c *cerebrasProviderInitializer) DefaultCapabilities() map[string]string {
	return map[string]string{
		string(ApiNameChatCompletion): PathOpenAIChatCompletions,
		string(ApiNameModels):         PathOpenAIModels,
	}
}

func (c *cerebrasProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	if config.openaiCustomUrl != "" {
		// Handle a custom URL the same way as the OpenAI provider:
		// strip the scheme, then split the remainder into domain and path.
		customUrl := strings.TrimPrefix(strings.TrimPrefix(config.openaiCustomUrl, "http://"), "https://")
		pairs := strings.SplitN(customUrl, "/", 2)
		customPath := "/"
		if len(pairs) == 2 {
			customPath += pairs[1]
		}
		capabilities := c.DefaultCapabilities()
		for key, mapPath := range capabilities {
			capabilities[key] = path.Join(customPath, strings.TrimPrefix(mapPath, "/v1"))
		}
		config.setDefaultCapabilities(capabilities)
		log.Debugf("ai-proxy: cerebras provider customDomain:%s, customPath:%s, capabilities:%v",
			pairs[0], customPath, capabilities)
		return &cerebrasProvider{
			config:       config,
			contextCache: createContextCache(&config),
			customDomain: pairs[0],
			customPath:   customPath,
		}, nil
	}

	// Set default capabilities
	config.setDefaultCapabilities(c.DefaultCapabilities())

	return &cerebrasProvider{
		config:       config,
		contextCache: createContextCache(&config),
	}, nil
}

type cerebrasProvider struct {
	config       ProviderConfig
	contextCache *contextCache
	customDomain string
	customPath   string
}

func (p *cerebrasProvider) GetProviderType() string {
	return providerTypeCerebras
}

func (p *cerebrasProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
	p.config.handleRequestHeaders(p, ctx, apiName)
	return nil
}

func (p *cerebrasProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	if p.customPath != "" {
		util.OverwriteRequestPathHeader(headers, p.customPath)
	} else if apiName != "" {
		util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities)
	}

	if p.customDomain != "" {
		util.OverwriteRequestHostHeader(headers, p.customDomain)
	} else {
		util.OverwriteRequestHostHeader(headers, defaultCerebrasDomain)
	}
	if len(p.config.apiTokens) > 0 {
		util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
	}
	headers.Del("Content-Length")
}

func (p *cerebrasProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	if !p.config.isSupportedAPI(apiName) {
		return types.ActionContinue, errUnsupportedApiName
	}
	return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body)
}

func (p *cerebrasProvider) GetApiName(path string) ApiName {
	if strings.Contains(path, PathOpenAIChatCompletions) {
		return ApiNameChatCompletion
	}
	if strings.Contains(path, PathOpenAIModels) {
		return ApiNameModels
	}
	return ""
}
2 changes: 2 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/provider.go
@@ -135,6 +135,7 @@ const (
	providerTypeOpenRouter = "openrouter"
	providerTypeLongcat    = "longcat"
	providerTypeFireworks  = "fireworks"
	providerTypeCerebras   = "cerebras"

	protocolOpenAI   = "openai"
	protocolOriginal = "original"
@@ -217,6 +218,7 @@ var (
		providerTypeOpenRouter: &openrouterProviderInitializer{},
		providerTypeLongcat:    &longcatProviderInitializer{},
		providerTypeFireworks:  &fireworksProviderInitializer{},
		providerTypeCerebras:   &cerebrasProviderInitializer{},
	}
)
