Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions test/src/ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use serde::{Deserialize, Serialize};
use worker::{Env, Model, Request, Response, Result, StreamableModel};

use crate::SomeSharedData;

/// Marker type for the `@cf/meta/llama-4-scout-17b-16e-instruct` Workers AI model.
pub struct Llama4Scout17b16eInstruct;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed this seems like the right approach!


/// Input payload for a plain text-generation run.
#[derive(Serialize)]
pub struct DefaultTextGenerationInput {
    // Prompt text sent to the model.
    pub prompt: String,
}

/// Output payload deserialized from a text-generation run.
#[derive(Deserialize)]
pub struct DefaultTextGenerationOutput {
    // Generated text returned by the model.
    pub response: String,
}

/// Lets a text-generation result be used directly as a byte payload.
impl From<DefaultTextGenerationOutput> for Vec<u8> {
    fn from(output: DefaultTextGenerationOutput) -> Self {
        // UTF-8 bytes of the generated text.
        output.response.into()
    }
}

/// Binds the Llama 4 Scout model name to its typed input/output payloads.
impl Model for Llama4Scout17b16eInstruct {
    // Model identifier passed to the Workers AI binding at run time.
    const MODEL_NAME: &str = "@cf/meta/llama-4-scout-17b-16e-instruct";

    type Input = DefaultTextGenerationInput;

    type Output = DefaultTextGenerationOutput;
}

// Opts this model into streaming responses via `run_streaming`.
impl StreamableModel for Llama4Scout17b16eInstruct {}

// Name of the Workers AI binding declared in `wrangler.toml` (`[ai] binding = "AI_TEST"`).
const AI_TEST: &str = "AI_TEST";

/// Test route: runs a one-shot (non-streaming) text generation against the
/// `AI_TEST` binding and returns the generated text as the response body.
#[worker::send]
pub async fn simple_ai_text_generation(
    _: Request,
    env: Env,
    _data: SomeSharedData,
) -> Result<Response> {
    let input = DefaultTextGenerationInput {
        prompt: "What is the answer to life the universe and everything?".to_owned(),
    };
    let output = env
        .ai(AI_TEST)?
        .run::<Llama4Scout17b16eInstruct>(input)
        .await?;
    Response::ok(output.response)
}

/// Test route: runs a streaming text generation against the `AI_TEST`
/// binding and forwards the typed output stream as the response body.
#[worker::send]
pub async fn streaming_ai_text_generation(
    _: Request,
    env: Env,
    _data: SomeSharedData,
) -> Result<Response> {
    let input = DefaultTextGenerationInput {
        prompt: "What is the answer to life the universe and everything?".to_owned(),
    };
    let stream = env
        .ai(AI_TEST)?
        .run_streaming::<Llama4Scout17b16eInstruct>(input)
        .await?;
    Response::from_stream(stream)
}
1 change: 1 addition & 0 deletions test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use worker::{console_log, event, js_sys, wasm_bindgen, Env, Result};
#[cfg(not(feature = "http"))]
use worker::{Request, Response};

mod ai;
mod alarm;
mod analytics_engine;
mod assets;
Expand Down
8 changes: 5 additions & 3 deletions test/src/router.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable, fetch,
form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket, sql_counter,
sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
ai, alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable,
fetch, form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket,
sql_counter, sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
};
#[cfg(feature = "http")]
use std::convert::TryInto;
Expand Down Expand Up @@ -112,6 +112,8 @@ macro_rules! add_route (

macro_rules! add_routes (
($obj:ident) => {
add_route!($obj, get, "/ai", ai::simple_ai_text_generation);
add_route!($obj, get, "/ai/streaming", ai::streaming_ai_text_generation);
add_route!($obj, get, sync, "/request", request::handle_a_request);
add_route!($obj, get, "/analytics-engine", analytics_engine::handle_analytics_event);
add_route!($obj, get, "/async-request", request::handle_async_request);
Expand Down
11 changes: 11 additions & 0 deletions test/tests/ai.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { describe, expect, test } from "vitest";
import { mf, mfUrl } from "./mf";

async function runTest() {
let normal_response = await mf.dispatchFetch(`${mfUrl}/ai`);
expect(normal_response.status).toBe(200);

Check failure on line 6 in test/tests/ai.spec.ts

View workflow job for this annotation

GitHub Actions / Test

tests/ai.spec.ts

AssertionError: expected 404 to be 200 // Object.is equality - Expected + Received - 200 + 404 ❯ runTest tests/ai.spec.ts:6:34

let streaming_response = await mf.dispatchFetch(`${mfUrl}/ai/streaming`);
expect(streaming_response.status).toBe(200);
}
describe("ai", runTest);
25 changes: 14 additions & 11 deletions test/wrangler.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name = "testing-rust-worker"
workers_dev = true
compatibility_date = "2025-09-23" # required
compatibility_date = "2025-09-23" # required
main = "build/worker/shim.mjs"

kv_namespaces = [
{ binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
{ binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
{ binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
{ binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
]

[vars]
Expand All @@ -22,14 +22,14 @@ service = "remote-service"

[durable_objects]
bindings = [
{ name = "COUNTER", class_name = "Counter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
{ name = "AUTO", class_name = "AutoResponseObject" },
{ name = "SQL_COUNTER", class_name = "SqlCounter" },
{ name = "SQL_ITERATOR", class_name = "SqlIterator" },
{ name = "MY_CLASS", class_name = "MyClass" },
{ name = "ECHO_CONTAINER", class_name = "EchoContainer" },
{ name = "COUNTER", class_name = "Counter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
{ name = "AUTO", class_name = "AutoResponseObject" },
{ name = "SQL_COUNTER", class_name = "SqlCounter" },
{ name = "SQL_ITERATOR", class_name = "SqlIterator" },
{ name = "MY_CLASS", class_name = "MyClass" },
{ name = "ECHO_CONTAINER", class_name = "EchoContainer" },
]

[[analytics_engine_datasets]]
Expand Down Expand Up @@ -84,3 +84,6 @@ secret_name = "secret-name"
class_name = "EchoContainer"
image = "./container-echo/Dockerfile"
max_instances = 1

[ai]
binding = "AI_TEST"
103 changes: 95 additions & 8 deletions worker/src/ai.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{self, Poll};

use crate::{env::EnvBinding, send::SendFuture};
use crate::{Error, Result};
use serde::de::DeserializeOwned;
use serde::Serialize;
use futures_util::io::{BufReader, Lines};
use futures_util::{ready, AsyncBufReadExt as _, Stream, StreamExt as _};
use js_sys::Reflect;
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use wasm_streams::readable::IntoAsyncRead;
use worker_sys::Ai as AiSys;

/// Enables access to Workers AI functionality.
Expand All @@ -14,20 +22,27 @@ impl Ai {
/// Execute a Workers AI operation using the specified model.
/// Various forms of the input are documented in the Workers
/// AI documentation.
pub async fn run<T: Serialize, U: DeserializeOwned>(
&self,
model: impl AsRef<str>,
input: T,
) -> Result<U> {
pub async fn run<M: Model>(&self, input: M::Input) -> Result<M::Output> {
let fut = SendFuture::new(JsFuture::from(
self.0
.run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
.run(M::MODEL_NAME, serde_wasm_bindgen::to_value(&input)?),
));
match fut.await {
Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
Err(err) => Err(Error::from(err)),
}
}

/// Runs streamable model `M` with `stream: true` set on the input object,
/// returning a typed [`AiStream`] over the response body.
pub async fn run_streaming<M: StreamableModel>(&self, input: M::Input) -> Result<AiStream<M>> {
    let js_input = serde_wasm_bindgen::to_value(&input)?;
    // Flag the request as streaming so the binding returns a ReadableStream.
    Reflect::set(&js_input, &JsValue::from_str("stream"), &JsValue::TRUE)?;

    let raw = SendFuture::new(JsFuture::from(self.0.run(M::MODEL_NAME, js_input)))
        .await?
        .dyn_into::<web_sys::ReadableStream>()?;
    let reader = wasm_streams::ReadableStream::from_raw(raw).into_async_read();

    Ok(AiStream::new(reader))
}
}

unsafe impl Sync for Ai {}
Expand Down Expand Up @@ -82,3 +97,75 @@ impl EnvBinding for Ai {
}
}
}

/// A Workers AI model with typed input and output.
///
/// Implementors tie a model name to the serializable input it accepts and
/// the deserializable output it produces, so callers need not repeat the
/// model string or raw serde types.
pub trait Model: 'static {
    /// Model identifier passed to the AI binding, e.g. `"@cf/meta/llama-4-scout-17b-16e-instruct"`.
    const MODEL_NAME: &str;
    /// Payload serialized and sent to the model.
    type Input: Serialize;
    /// Response shape deserialized from the model's output.
    type Output: DeserializeOwned;
}

/// A [`Model`] that additionally supports streamed responses via `run_streaming`.
pub trait StreamableModel: Model {}

/// A stream of typed outputs from a [`StreamableModel`] run.
///
/// Wraps the raw response body and yields one deserialized `T::Output`
/// per event line (the body is parsed line-by-line in `poll_next`).
#[derive(Debug)]
#[pin_project]
pub struct AiStream<T: StreamableModel> {
    // Line-by-line reader over the raw response body.
    #[pin]
    inner: Lines<BufReader<IntoAsyncRead<'static>>>,
    // Ties the stream to the model type so items deserialize as `T::Output`.
    phantom: PhantomData<T>,
}

impl<T: StreamableModel> AiStream<T> {
    /// Wraps a raw async byte stream in a buffered line reader so events
    /// can be parsed one line at a time.
    pub fn new(stream: IntoAsyncRead<'static>) -> Self {
        let inner = BufReader::new(stream).lines();
        Self {
            inner,
            phantom: PhantomData,
        }
    }
}

impl<T: StreamableModel> Stream for AiStream<T> {
    type Item = Result<T::Output>;

    /// Polls the underlying line stream for the next model output.
    ///
    /// Blank lines (event separators) are skipped, a leading `data: `
    /// prefix is stripped when present, and a `[DONE]` payload terminates
    /// the stream. Each remaining payload is deserialized as `T::Output`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            let line = match ready!(this.inner.poll_next_unpin(cx)) {
                Some(Ok(line)) => line,
                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                None => return Poll::Ready(None),
            };

            // Skip *every* consecutive blank line, not just one: SSE-style
            // bodies may contain runs of blank separators, and feeding an
            // empty string to the JSON parser would yield a spurious error.
            if line.is_empty() {
                continue;
            }

            let payload = line.strip_prefix("data: ").unwrap_or(&line);

            // End-of-stream sentinel emitted after the final event.
            if payload == "[DONE]" {
                return Poll::Ready(None);
            }

            return Poll::Ready(Some(Ok(serde_json::from_str(payload)?)));
        }
    }
}
Loading