diff --git a/test/src/ai.rs b/test/src/ai.rs
new file mode 100644
index 000000000..2acc4f364
--- /dev/null
+++ b/test/src/ai.rs
@@ -0,0 +1,47 @@
+//! Routes exercising the Workers AI binding (plain and streaming text generation).
+
+use worker::{
+    models::llama_4_scout_17b_16e_instruct::Llama4Scout17b16eInstruct,
+    worker_sys::AiTextGenerationInput, Env, Request, Response, Result,
+};
+
+use crate::SomeSharedData;
+
+/// Name of the `[ai]` binding declared in `wrangler.toml`.
+const AI_TEST: &str = "AI_TEST";
+
+/// Prompt sent to the model by both test routes.
+const PROMPT: &str = "What is the answer to life the universe and everything?";
+
+/// Runs a single text-generation request and responds with the generated text.
+///
+/// Responds with an empty body when the model output carries no `response` value.
+#[worker::send]
+pub async fn simple_ai_text_generation(
+    _: Request,
+    env: Env,
+    _data: SomeSharedData,
+) -> Result<Response> {
+    let output = env
+        .ai(AI_TEST)?
+        .run::<Llama4Scout17b16eInstruct>(AiTextGenerationInput::new().set_prompt(PROMPT))
+        .await?;
+    Response::ok(output.get_response().unwrap_or_default())
+}
+
+/// Runs the same request through `run_streaming` and uses the stream as the response body.
+#[worker::send]
+pub async fn streaming_ai_text_generation(
+    _: Request,
+    env: Env,
+    _data: SomeSharedData,
+) -> Result<Response> {
+    let stream = env
+        .ai(AI_TEST)?
+        .run_streaming::<Llama4Scout17b16eInstruct>(
+            AiTextGenerationInput::new().set_prompt(PROMPT),
+        )
+        .await?;
+
+    Response::from_stream(stream)
+}
diff --git a/test/src/lib.rs b/test/src/lib.rs
index 91b082cf8..86b3d9eac 100644
--- a/test/src/lib.rs
+++ b/test/src/lib.rs
@@ -12,6 +12,7 @@ use worker::{console_log, event, js_sys, wasm_bindgen, Env, Result};
 #[cfg(not(feature = "http"))]
 use worker::{Request, Response};
 
+mod ai;
 mod alarm;
 mod analytics_engine;
 mod assets;
diff --git a/test/src/router.rs b/test/src/router.rs
index ff1d2c2ac..f8c3ff933 100644
--- a/test/src/router.rs
+++ b/test/src/router.rs
@@ -1,7 +1,7 @@
 use crate::{
-    alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable, fetch,
-    form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket, sql_counter,
-    sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
+    ai, alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable,
+    fetch, form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket,
+    sql_counter,
sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
 };
 #[cfg(feature = "http")]
 use std::convert::TryInto;
@@ -112,6 +112,8 @@ macro_rules! add_route (
 
 macro_rules! add_routes (
     ($obj:ident) => {
+        add_route!($obj, get, "/ai", ai::simple_ai_text_generation);
+        add_route!($obj, get, "/ai/streaming", ai::streaming_ai_text_generation);
         add_route!($obj, get, sync, "/request", request::handle_a_request);
         add_route!($obj, get, "/analytics-engine", analytics_engine::handle_analytics_event);
         add_route!($obj, get, "/async-request", request::handle_async_request);
diff --git a/test/tests/ai.spec.ts b/test/tests/ai.spec.ts
new file mode 100644
index 000000000..c07c6e824
--- /dev/null
+++ b/test/tests/ai.spec.ts
@@ -0,0 +1,15 @@
+import { describe, expect, test } from "vitest";
+import { mf, mfUrl } from "./mf";
+
+// Each route is registered as its own test so the assertions run as tests, not during collection.
+describe("ai", () => {
+  test("text generation", async () => {
+    const response = await mf.dispatchFetch(`${mfUrl}ai`);
+    expect(response.status).toBe(200);
+  });
+
+  test("streaming text generation", async () => {
+    const response = await mf.dispatchFetch(`${mfUrl}ai/streaming`);
+    expect(response.status).toBe(200);
+  });
+});
diff --git a/test/wrangler.toml b/test/wrangler.toml
index 1a1c439b9..fcc913259 100644
--- a/test/wrangler.toml
+++ b/test/wrangler.toml
@@ -1,11 +1,11 @@
 name = "testing-rust-worker"
 workers_dev = true
-compatibility_date = "2025-09-23" # required
+compatibility_date = "2025-09-23" # required
 main = "build/worker/shim.mjs"
 
 kv_namespaces = [
-  { binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
-  { binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
+    { binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
+    { binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
 ]
 
 [vars]
@@ -22,14 +22,14 @@ service = "remote-service"
 
 [durable_objects]
 bindings = [
-  { name = "COUNTER", class_name = "Counter" },
-  { name = "ALARM", class_name = "AlarmObject" },
-  { name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
-  { name = "AUTO", class_name = "AutoResponseObject" },
-  { name = "SQL_COUNTER", class_name = "SqlCounter" },
-  { name = "SQL_ITERATOR", class_name = "SqlIterator" },
-  { name = "MY_CLASS", class_name = "MyClass" },
-  { name = "ECHO_CONTAINER", class_name = "EchoContainer" },
+    { name = "COUNTER", class_name = "Counter" },
+    { name = "ALARM", class_name = "AlarmObject" },
+    { name = "PUT_RAW_TEST_OBJECT",
class_name = "PutRawTestObject" }, - { name = "AUTO", class_name = "AutoResponseObject" }, - { name = "SQL_COUNTER", class_name = "SqlCounter" }, - { name = "SQL_ITERATOR", class_name = "SqlIterator" }, - { name = "MY_CLASS", class_name = "MyClass" }, - { name = "ECHO_CONTAINER", class_name = "EchoContainer" }, + { name = "COUNTER", class_name = "Counter" }, + { name = "ALARM", class_name = "AlarmObject" }, + { name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" }, + { name = "AUTO", class_name = "AutoResponseObject" }, + { name = "SQL_COUNTER", class_name = "SqlCounter" }, + { name = "SQL_ITERATOR", class_name = "SqlIterator" }, + { name = "MY_CLASS", class_name = "MyClass" }, + { name = "ECHO_CONTAINER", class_name = "EchoContainer" }, ] [[analytics_engine_datasets]] @@ -84,3 +84,6 @@ secret_name = "secret-name" class_name = "EchoContainer" image = "./container-echo/Dockerfile" max_instances = 1 + +[ai] +binding = "AI_TEST" diff --git a/worker-sys/Cargo.toml b/worker-sys/Cargo.toml index a44baacd2..612b1302e 100644 --- a/worker-sys/Cargo.toml +++ b/worker-sys/Cargo.toml @@ -10,34 +10,35 @@ description = "Low-level extern definitions / FFI bindings to the Cloudflare Wor [dependencies] js-sys.workspace = true wasm-bindgen.workspace = true + cfg-if = "1.0.1" [dependencies.web-sys] version = "0.3.70" features = [ - "ReadableStream", - "WritableStream", - "RequestRedirect", - "RequestInit", - "FormData", - "Blob", - "BinaryType", - "ErrorEvent", - "MessageEvent", - "CloseEvent", - "ProgressEvent", - "WebSocket", - "TransformStream", - "AbortController", - "console", - "ResponseInit", - "Cache", - "CacheStorage", - "CacheQueryOptions", - "AbortSignal", - "Headers", - "Request", - "Response", + "ReadableStream", + "WritableStream", + "RequestRedirect", + "RequestInit", + "FormData", + "Blob", + "BinaryType", + "ErrorEvent", + "MessageEvent", + "CloseEvent", + "ProgressEvent", + "WebSocket", + "TransformStream", + "AbortController", + "console", + 
"ResponseInit", + "Cache", + "CacheStorage", + "CacheQueryOptions", + "AbortSignal", + "Headers", + "Request", + "Response", ] [features] diff --git a/worker-sys/src/types.rs b/worker-sys/src/types.rs index d6e63d4db..7861d282f 100644 --- a/worker-sys/src/types.rs +++ b/worker-sys/src/types.rs @@ -45,3 +45,4 @@ pub use tls_client_auth::*; pub use version::*; pub use websocket_pair::*; pub use websocket_request_response_pair::*; +pub mod utils; diff --git a/worker-sys/src/types/ai.rs b/worker-sys/src/types/ai.rs index f362138ba..a3a4a3d78 100644 --- a/worker-sys/src/types/ai.rs +++ b/worker-sys/src/types/ai.rs @@ -1,6 +1,12 @@ -use js_sys::Promise; +use std::iter::FromIterator; + +use js_sys::{Array, Promise}; use wasm_bindgen::prelude::*; +use crate::typed_array; + +use super::utils::typed_array::TypedArrayBuilder; + #[wasm_bindgen] extern "C" { #[wasm_bindgen(extends=::js_sys::Object, js_name=Ai)] @@ -10,3 +16,226 @@ extern "C" { #[wasm_bindgen(structural, method, js_class=Ai, js_name=run)] pub fn run(this: &Ai, model: &str, input: JsValue) -> Promise; } + +typed_array!(RoleScopedChatInputArray, RoleScopedChatInput); + +#[wasm_bindgen] +extern "C" { + # [wasm_bindgen (extends = :: js_sys :: Object)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type RoleScopedChatInput; + + #[wasm_bindgen(constructor, js_class = Object)] + fn new() -> RoleScopedChatInput; + + #[wasm_bindgen(method, setter = "role")] + fn set_role_inner(this: &RoleScopedChatInput, role: &str); + #[wasm_bindgen(method, getter = "role")] + fn get_role_inner(this: &RoleScopedChatInput) -> Option; + + #[wasm_bindgen(method, setter = "content")] + fn set_content_inner(this: &RoleScopedChatInput, content: &str); + #[wasm_bindgen(method, getter = "content")] + fn get_content_inner(this: &RoleScopedChatInput) -> Option; +} + +#[derive(Default, Debug)] +pub enum Role { + #[default] + User, + Assistant, + System, + Tool, + Any(String), +} + +impl RoleScopedChatInput { + pub fn get_role(&self) -> Role { + 
match self.get_role_inner().as_deref() { + Some("user") => Role::User, + Some("assistant") => Role::Assistant, + Some("system") => Role::System, + Some("tool") => Role::Tool, + Some(any) => Role::Any(any.to_owned()), + None => Role::default(), + } + } + + pub fn get_content(&self) -> String { + self.get_content_inner().unwrap_or_default() + } + + pub fn custom_role(role: &str, content: &str) -> Self { + let message = RoleScopedChatInput::new(); + message.set_role_inner(role); + message.set_content_inner(content); + message + } + + pub fn user(content: &str) -> Self { + Self::custom_role("user", content) + } + + pub fn assistant(content: &str) -> Self { + Self::custom_role("assistant", content) + } + + pub fn system(content: &str) -> Self { + Self::custom_role("system", content) + } + + pub fn tool(content: &str) -> Self { + Self::custom_role("tool", content) + } + + pub fn builder<'a>() -> TypedArrayBuilder<'a, RoleScopedChatInputArray::RoleScopedChatInputArray> + { + TypedArrayBuilder::new() + } +} + +#[wasm_bindgen] +extern "C" { + # [wasm_bindgen (extends = :: js_sys :: Object)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type AiTextGenerationInput; + + #[wasm_bindgen(constructor, js_class = Object)] + pub fn new() -> AiTextGenerationInput; + + #[wasm_bindgen(method, setter = "prompt")] + fn set_prompt_inner(this: &AiTextGenerationInput, prompt: &str); + #[wasm_bindgen(method, getter = "prompt")] + pub fn get_prompt(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "raw")] + fn set_raw_inner(this: &AiTextGenerationInput, raw: bool); + #[wasm_bindgen(method, getter = "raw")] + pub fn get_raw(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "max_tokens")] + fn set_max_tokens_inner(this: &AiTextGenerationInput, max_tokens: u32); + #[wasm_bindgen(method, getter = "max_tokens")] + pub fn get_max_tokens(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "temperature")] + fn 
set_temperature_inner(this: &AiTextGenerationInput, temperature: f32); + #[wasm_bindgen(method, getter = "temperature")] + pub fn get_temperature(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "top_p")] + fn set_top_p_inner(this: &AiTextGenerationInput, top_p: f32); + #[wasm_bindgen(method, getter = "top_p")] + pub fn get_top_p(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "top_k")] + fn set_top_k_inner(this: &AiTextGenerationInput, top_p: u32); + #[wasm_bindgen(method, getter = "top_k")] + pub fn get_top_k(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "seed")] + fn set_seed_inner(this: &AiTextGenerationInput, seed: u64); + #[wasm_bindgen(method, getter = "seed")] + pub fn get_seed(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "repetition_penalty")] + fn set_repetition_penalty_inner(this: &AiTextGenerationInput, repetition_penalty: f32); + #[wasm_bindgen(method, getter = "repetition_penalty")] + pub fn get_repetition_penalty(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "frequency_penalty")] + fn set_frequency_penalty_inner(this: &AiTextGenerationInput, frequency_penalty: f32); + #[wasm_bindgen(method, getter = "frequency_penalty")] + pub fn get_frequency_penalty(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "presence_penalty")] + fn set_presence_penalty_inner(this: &AiTextGenerationInput, presence_penalty: f32); + #[wasm_bindgen(method, getter = "presence_penalty")] + pub fn get_presence_penalty(this: &AiTextGenerationInput) -> Option; + + #[wasm_bindgen(method, setter = "messages")] + fn set_messages_inner(this: &AiTextGenerationInput, messages: Array); + #[wasm_bindgen(method, getter = "messages")] + pub fn get_messages(this: &AiTextGenerationInput) -> Option>; + +} + +impl AiTextGenerationInput { + pub fn set_prompt(self, prompt: &str) -> Self { + self.set_prompt_inner(prompt); + 
self + } + + pub fn set_raw(self, raw: bool) -> Self { + self.set_raw_inner(raw); + self + } + + pub fn set_max_tokens(self, max_tokens: u32) -> Self { + self.set_max_tokens_inner(max_tokens); + self + } + + pub fn set_temperature(self, temperature: f32) -> Self { + self.set_temperature_inner(temperature); + self + } + + pub fn set_top_p(self, top_p: f32) -> Self { + self.set_top_p_inner(top_p); + self + } + + pub fn set_top_k(self, top_k: u32) -> Self { + self.set_top_k_inner(top_k); + self + } + + pub fn set_seed(self, seed: u64) -> Self { + self.set_seed_inner(seed); + self + } + + pub fn set_repetition_penalty(self, repetition_penalty: f32) -> Self { + self.set_repetition_penalty_inner(repetition_penalty); + self + } + + pub fn set_frequency_penalty(self, frequency_penalty: f32) -> Self { + self.set_frequency_penalty_inner(frequency_penalty); + self + } + + pub fn set_presence_penalty(self, presence_penalty: f32) -> Self { + self.set_presence_penalty_inner(presence_penalty); + self + } + + pub fn set_messages(self, messages: &[RoleScopedChatInput]) -> Self { + self.set_messages_inner(Array::from_iter(messages)); + self + } +} + +#[wasm_bindgen] +extern "C" { + # [wasm_bindgen (extends = :: js_sys :: Object)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type AiTextGenerationOutput; + + #[wasm_bindgen(constructor, js_class = Object)] + pub fn new() -> AiTextGenerationOutput; + + #[wasm_bindgen(method, getter = "response")] + pub fn get_response(this: &AiTextGenerationOutput) -> Option; + +} + +impl From for Vec { + fn from(value: AiTextGenerationOutput) -> Self { + value + .get_response() + .map(|text| text.into_bytes()) + .unwrap_or_default() + } +} diff --git a/worker-sys/src/types/utils.rs b/worker-sys/src/types/utils.rs new file mode 100644 index 000000000..1a9fbcc44 --- /dev/null +++ b/worker-sys/src/types/utils.rs @@ -0,0 +1 @@ +pub mod typed_array; diff --git a/worker-sys/src/types/utils/typed_array.rs b/worker-sys/src/types/utils/typed_array.rs new file 
mode 100644
index 000000000..577bc682b
--- /dev/null
+++ b/worker-sys/src/types/utils/typed_array.rs
@@ -0,0 +1,433 @@
+/// Minimal interface the generated typed arrays expose to [`TypedArrayBuilder`].
+pub(crate) trait TypedArray {
+    type Item;
+
+    fn new_with_length(len: u32) -> Self;
+
+    fn push(&self, item: &Self::Item) -> u32;
+}
+
+/// Collects borrowed items and turns them into a typed array on [`build`](Self::build).
+#[allow(private_bounds)] // `TypedArray` is `pub(crate)` while the builder is public.
+pub struct TypedArrayBuilder<'a, T: TypedArray> {
+    items: Vec<&'a T::Item>,
+}
+
+impl<'a, T: TypedArray> Default for TypedArrayBuilder<'a, T> {
+    fn default() -> Self {
+        Self { items: Vec::new() }
+    }
+}
+
+#[allow(private_bounds)]
+impl<'a, T: TypedArray> TypedArrayBuilder<'a, T> {
+    /// Creates an empty builder.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Appends `item`; items end up in the array in the order they were pushed.
+    pub fn push(mut self, item: impl Into<&'a T::Item>) -> TypedArrayBuilder<'a, T> {
+        self.items.push(item.into());
+        self
+    }
+
+    /// Creates the array and pushes every collected item into it.
+    pub fn build(self) -> T {
+        // Start empty: `push` appends, so a pre-sized JS array would keep leading holes.
+        let array = T::new_with_length(0);
+        for item in self.items {
+            array.push(item);
+        }
+        array
+    }
+}
+
+/// Declares a private module `$name` holding a JS `Array` binding, also named `$name`,
+/// whose element type is `$type`, together with iterator, `FromIterator`, `Extend` and
+/// `TypedArray` implementations for it.
+#[macro_export]
+macro_rules! typed_array {
+    ($name:ident, $type:ident) => {
+        #[allow(non_snake_case)]
+        mod $name {
+            use super::$type;
+            use ::wasm_bindgen::prelude::*;
+            use $crate::utils::typed_array::TypedArray;
+
+            #[wasm_bindgen]
+            extern "C" {
+                #[wasm_bindgen(extends = ::js_sys::Array)]
+                #[derive(Debug, Clone, PartialEq, Eq)]
+                pub type $name;
+
+                #[wasm_bindgen(constructor, js_class = Array)]
+                fn new() -> $name;
+
+                // `js_class = Array`: the Rust-side name `$name` is not a JS global.
+                #[wasm_bindgen(constructor, js_class = Array)]
+                fn new_with_length(len: u32) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn at(this: &$name, index: i32) -> Option<$type>;
+
+                #[wasm_bindgen(method, structural, indexing_getter)]
+                pub fn get(this: &$name, index: u32) -> Option<$type>;
+
+                #[wasm_bindgen(method, structural, indexing_setter)]
+                pub fn set(this: &$name, index: u32, value: $type);
+
+                #[wasm_bindgen(method, structural, indexing_deleter)]
+                pub fn delete(this: &$name, index: u32);
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array)]
+                pub fn from(val: &$name) -> $name;
+
+                #[wasm_bindgen(method, js_name = copyWithin)]
+                pub fn copy_within(this: &$name, target: i32, start: i32, end: i32) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn concat(this: &$name, array: &$name) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn every(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> bool,
+                ) -> bool;
+
+                #[wasm_bindgen(method)]
+                pub fn fill(this: &$name, value: &$type, start: u32, end: u32) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn filter(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> bool,
+                ) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn find(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> bool,
+                ) -> Option<$type>;
+
+                #[wasm_bindgen(method, js_name = findIndex)]
+                pub fn find_index(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> bool,
+                ) -> i32;
+
+                #[wasm_bindgen(method, js_name = findLast)]
+                pub fn find_last(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> bool,
+                ) -> Option<$type>;
+
+                #[wasm_bindgen(method, js_name = findLastIndex)]
+                pub fn find_last_index(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> bool,
+                ) -> i32;
+
+                #[wasm_bindgen(method)]
+                pub fn flat(this: &$name, depth: i32) -> $name;
+
+                #[wasm_bindgen(method, js_name = flatMap)]
+                pub fn flat_map(
+                    this: &$name,
+                    callback: &mut dyn FnMut($type, u32, $name) -> Vec<$type>,
+                ) -> $name;
+
+                #[wasm_bindgen(method, js_name = forEach)]
+                pub fn for_each(this: &$name, callback: &mut dyn FnMut($type, u32, $name));
+
+                #[wasm_bindgen(method)]
+                pub fn includes(this: &$name, value: &$type, from_index: i32) -> bool;
+
+                #[wasm_bindgen(method, js_name = indexOf)]
+                pub fn index_of(this: &$name, value: &$type, from_index: i32) -> i32;
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array, js_name = isArray)]
+                pub fn is_array(value: &$type) -> bool;
+
+                #[wasm_bindgen(method)]
+                pub fn join(this: &$name, delimiter: &str) -> ::js_sys::JsString;
+
+                #[wasm_bindgen(method, js_name = lastIndexOf)]
+                pub fn last_index_of(this: &$name, value: &$type, from_index: i32) -> i32;
+
+                #[wasm_bindgen(method, getter, structural)]
+                pub fn length(this: &$name) -> u32;
+
+                #[wasm_bindgen(method, setter)]
+                pub fn set_length(this: &$name, value: u32);
+
+                #[wasm_bindgen(method)]
+                pub fn map(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, u32, $name) -> $type,
+                ) -> $name;
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array, js_name = of)]
+                pub fn of1(a: &$type) -> $name;
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array, js_name = of)]
+                pub fn of2(a: &$type, b: &$type) -> $name;
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array, js_name = of)]
+                pub fn of3(a: &$type, b: &$type, c: &$type) -> $name;
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array, js_name = of)]
+                pub fn of4(a: &$type, b: &$type, c: &$type, d: &$type) -> $name;
+
+                #[wasm_bindgen(static_method_of = $name, js_class = Array, js_name = of)]
+                pub fn of5(a: &$type, b: &$type, c: &$type, d: &$type, e: &$type) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn pop(this: &$name) -> $type;
+
+                #[wasm_bindgen(method)]
+                pub fn push(this: &$name, value: &$type) -> u32;
+
+                #[wasm_bindgen(method)]
+                pub fn reduce(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, $type, u32, $name) -> $type,
+                    initial_value: &$type,
+                ) -> $type;
+
+                #[wasm_bindgen(method, js_name = reduceRight)]
+                pub fn reduce_right(
+                    this: &$name,
+                    predicate: &mut dyn FnMut($type, $type, u32, $name) -> $type,
+                    initial_value: &$type,
+                ) -> $type;
+
+                #[wasm_bindgen(method)]
+                pub fn reverse(this: &$name) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn shift(this: &$name) -> $type;
+
+                #[wasm_bindgen(method)]
+                pub fn slice(this: &$name, start: u32, end: u32) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn some(this: &$name, predicate: &mut dyn FnMut($type) -> bool) -> bool;
+
+                #[wasm_bindgen(method)]
+                pub fn sort(this: &$name) -> $name;
+
+                #[wasm_bindgen(method)]
+                pub fn splice(this: &$name, start: u32, delete_count: u32, item: &$type) -> $name;
+
+                #[wasm_bindgen(method, js_name = toLocaleString)]
+                pub fn to_locale_string(
+                    this: &$name,
+                    locales: &$type,
+                    options: &$type,
+                ) -> ::js_sys::JsString;
+
+                #[wasm_bindgen(method, js_name = toString)]
+                pub fn to_string(this: &$name) -> ::js_sys::JsString;
+
+                #[wasm_bindgen(method)]
+                pub fn unshift(this: &$name, value: &$type) -> u32;
+            }
+
+            #[derive(Debug, Clone)]
+            pub struct ArrayIntoIter {
+                range: core::ops::Range<u32>,
+                array: $name,
+            }
+
+            impl core::iter::Iterator for ArrayIntoIter {
+                type Item = $type;
+
+                fn next(&mut self) -> Option<Self::Item> {
+                    let index = self.range.next()?;
+                    self.array.get(index)
+                }
+
+                #[inline]
+                fn size_hint(&self) -> (usize, Option<usize>) {
+                    self.range.size_hint()
+                }
+
+                #[inline]
+                fn count(self) -> usize
+                where
+                    Self: Sized,
+                {
+                    self.range.count()
+                }
+
+                #[inline]
+                fn last(self) -> Option<Self::Item>
+                where
+                    Self: Sized,
+                {
+                    let Self { range, array } = self;
+                    range.last().map(|index| array.get(index)).flatten()
+                }
+
+                #[inline]
+                fn nth(&mut self, n: usize) -> Option<Self::Item> {
+                    self.range
+                        .nth(n)
+                        .map(|index| self.array.get(index))
+                        .flatten()
+                }
+            }
+
+            impl core::iter::DoubleEndedIterator for ArrayIntoIter {
+                fn next_back(&mut self) -> Option<Self::Item> {
+                    let index = self.range.next_back()?;
+                    self.array.get(index)
+                }
+
+                fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
+                    self.range
+                        .nth_back(n)
+                        .map(|index| self.array.get(index))
+                        .flatten()
+                }
+            }
+
+            impl core::iter::FusedIterator for ArrayIntoIter {}
+
+            impl core::iter::ExactSizeIterator for ArrayIntoIter {}
+
+            #[derive(Debug, Clone)]
+            pub struct ArrayIter<'a> {
+                range: core::ops::Range<u32>,
+                array: &'a $name,
+            }
+
+            impl core::iter::Iterator for ArrayIter<'_> {
+                type Item = $type;
+
+                fn next(&mut self) -> Option<Self::Item> {
+                    let index = self.range.next()?;
+                    self.array.get(index)
+                }
+
+                #[inline]
+                fn size_hint(&self) -> (usize, Option<usize>) {
+                    self.range.size_hint()
+                }
+
+                #[inline]
+                fn count(self) -> usize
+                where
+                    Self: Sized,
+                {
+                    self.range.count()
+                }
+
+                #[inline]
+                fn last(self) -> Option<Self::Item>
+                where
+                    Self: Sized,
+                {
+                    let Self { range, array } = self;
+                    range.last().map(|index| array.get(index)).flatten()
+                }
+
+                #[inline]
+                fn nth(&mut self, n: usize) -> Option<Self::Item> {
+                    self.range
+                        .nth(n)
+                        .map(|index| self.array.get(index))
+                        .flatten()
+                }
+            }
+
+            impl core::iter::DoubleEndedIterator for ArrayIter<'_> {
+                fn next_back(&mut self) -> Option<Self::Item> {
+                    let index = self.range.next_back()?;
+                    self.array.get(index)
+                }
+
+                fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
+                    self.range
+                        .nth_back(n)
+                        .map(|index| self.array.get(index))
+                        .flatten()
+                }
+            }
+
+            impl core::iter::FusedIterator for ArrayIter<'_> {}
+
+            impl core::iter::ExactSizeIterator for ArrayIter<'_> {}
+
+            impl $name {
+                /// Returns an iterator over the values of the JS array.
+                pub fn iter(&self) -> ArrayIter<'_> {
+                    ArrayIter {
+                        range: 0..self.length(),
+                        array: self,
+                    }
+                }
+
+                /// Converts the JS array into a new Vec.
+                pub fn to_vec(&self) -> Vec<Option<$type>> {
+                    let len = self.length();
+
+                    let mut output = Vec::with_capacity(len as usize);
+                    for i in 0..len {
+                        output.push(self.get(i));
+                    }
+
+                    output
+                }
+            }
+
+            impl core::iter::IntoIterator for $name {
+                type Item = $type;
+                type IntoIter = ArrayIntoIter;
+
+                fn into_iter(self) -> Self::IntoIter {
+                    ArrayIntoIter {
+                        range: 0..self.length(),
+                        array: self,
+                    }
+                }
+            }
+
+            // TODO pre-initialize the Array with the correct length using TrustedLen
+            impl core::iter::FromIterator<$type> for $name {
+                fn from_iter<T>(iter: T) -> $name
+                where
+                    T: IntoIterator<Item = $type>,
+                {
+                    let mut out = $name::new();
+                    out.extend(iter);
+                    out
+                }
+            }
+
+            impl core::iter::Extend<$type> for $name {
+                fn extend<T>(&mut self, iter: T)
+                where
+                    T: IntoIterator<Item = $type>,
+                {
+                    for value in iter {
+                        self.push(value.as_ref());
+                    }
+                }
+            }
+
+            impl TypedArray for $name {
+                type Item = $type;
+
+                fn new_with_length(len: u32) -> $name {
+                    Self::new_with_length(len)
+                }
+
+                fn push(&self, item: &Self::Item) -> u32 {
+                    self.push(&item)
+                }
+            }
+        }
+    };
+}
diff --git a/worker/src/ai.rs b/worker/src/ai.rs
index da2bdbdff..83f0240e7 100644
--- a/worker/src/ai.rs
+++ b/worker/src/ai.rs
@@ -1,11 +1,21 @@
+use std::marker::PhantomData;
+use std::pin::Pin;
+use std::task::{self, Poll};
+
 use crate::{env::EnvBinding, send::SendFuture};
 use crate::{Error, Result};
-use serde::de::DeserializeOwned;
-use serde::Serialize;
+use futures_util::io::{BufReader, Lines};
+use futures_util::{ready, AsyncBufReadExt as _, Stream, StreamExt as _};
+use js_sys::Reflect;
+use js_sys::JSON::parse;
+use pin_project::pin_project;
 use wasm_bindgen::{JsCast, JsValue};
 use wasm_bindgen_futures::JsFuture;
+use wasm_streams::readable::IntoAsyncRead;
 use worker_sys::Ai as AiSys;
 
+pub mod models;
+
 /// Enables access to Workers AI functionality.
 #[derive(Debug)]
 pub struct Ai(AiSys);
@@ -14,20 +24,24 @@ impl Ai {
     /// Execute a Workers AI operation using the specified model.
/// Various forms of the input are documented in the Workers
     /// AI documentation.
-    pub async fn run<T: Serialize, U: DeserializeOwned>(
-        &self,
-        model: impl AsRef<str>,
-        input: T,
-    ) -> Result<U> {
-        let fut = SendFuture::new(JsFuture::from(
-            self.0
-                .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
-        ));
+    pub async fn run<M: Model>(&self, input: M::Input) -> Result<M::Output> {
+        let fut = SendFuture::new(JsFuture::from(self.0.run(M::MODEL_NAME, input.into())));
         match fut.await {
-            Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
+            Ok(output) => Ok(output.into()),
             Err(err) => Err(Error::from(err)),
         }
     }
+
+    pub async fn run_streaming<M: StreamableModel>(&self, input: M::Input) -> Result<AiStream<M>> {
+        let input: JsValue = input.into();
+        Reflect::set(&input, &JsValue::from_str("stream"), &JsValue::TRUE)?;
+
+        let fut = SendFuture::new(JsFuture::from(self.0.run(M::MODEL_NAME, input)));
+        let raw_stream = fut.await?.dyn_into::<web_sys::ReadableStream>()?;
+        let stream = wasm_streams::ReadableStream::from_raw(raw_stream).into_async_read();
+
+        Ok(AiStream::new(stream))
+    }
 }
 
 unsafe impl Sync for Ai {}
@@ -82,3 +96,61 @@ impl EnvBinding for Ai {
         }
     }
 }
+
+/// A Workers AI model, described by its name and its input/output types.
+pub trait Model: 'static {
+    /// Model identifier passed to the binding's `run` call.
+    const MODEL_NAME: &'static str;
+    /// Request payload; converted into a `JsValue` before the call.
+    type Input: Into<JsValue>;
+    /// Response payload; converted from the resolved `JsValue`.
+    type Output: From<JsValue>;
+}
+
+/// Marker for models whose output can be consumed as an [`AiStream`].
+pub trait StreamableModel: Model {}
+
+/// Stream of model outputs decoded line by line from a streaming response.
+#[derive(Debug)]
+#[pin_project]
+pub struct AiStream<T: StreamableModel> {
+    #[pin]
+    inner: Lines<BufReader<IntoAsyncRead<'static>>>,
+    phantom: PhantomData<T>,
+}
+
+impl<T: StreamableModel> AiStream<T> {
+    /// Wraps the raw byte stream in a buffered line reader.
+    pub fn new(stream: IntoAsyncRead<'static>) -> Self {
+        Self {
+            inner: BufReader::new(stream).lines(),
+            phantom: PhantomData,
+        }
+    }
+}
+
+impl<T: StreamableModel> Stream for AiStream<T> {
+    type Item = Result<T::Output>;
+
+    /// Yields one parsed output per non-empty line and ends at the `[DONE]` marker.
+    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut this = self.project();
+
+        // Skip every blank separator line, not just the first one.
+        let line = loop {
+            match ready!(this.inner.poll_next_unpin(cx)) {
+                Some(Ok(line)) if line.is_empty() => continue,
+                Some(Ok(line)) => break line,
+                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
+                None => return Poll::Ready(None),
+            }
+        };
+
+        let payload = line.strip_prefix("data: ").unwrap_or(&line);
+        if payload == "[DONE]" {
+            return Poll::Ready(None);
+        }
+
+        Poll::Ready(Some(Ok(parse(payload)?.into())))
+    }
+}
diff --git a/worker/src/ai/models.rs b/worker/src/ai/models.rs
new file mode 100644
index 000000000..372239827
--- /dev/null
+++ b/worker/src/ai/models.rs
@@ -0,0 +1,74 @@
+pub mod llama_4_scout_17b_16e_instruct;
+
+/// Plain-Rust, serde-serializable chat message types.
+pub mod scoped_chat {
+    use serde::{Serialize, Serializer};
+
+    /// Author of a chat message; `Any` carries a custom role string.
+    #[derive(Default, Debug)]
+    pub enum Role {
+        #[default]
+        User,
+        Assistant,
+        System,
+        Tool,
+        Any(String),
+    }
+
+    impl Role {
+        /// Role name used in the serialized form.
+        pub fn as_str(&self) -> &str {
+            match self {
+                Role::User => "user",
+                Role::Assistant => "assistant",
+                Role::System => "system",
+                Role::Tool => "tool",
+                Role::Any(role) => role,
+            }
+        }
+    }
+
+    // Hand-written: an enum-level `#[serde(untagged)]` turns the unit variants into `null`.
+    impl Serialize for Role {
+        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+            serializer.serialize_str(self.as_str())
+        }
+    }
+
+    /// One chat message; `name` is left out of the serialized form when unset.
+    #[derive(Default, Debug, Serialize)]
+    pub struct RoleScopedChatInput {
+        pub role: Role,
+        pub content: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub name: Option<String>,
+    }
+
+    fn message(role: Role, content: &str) -> RoleScopedChatInput {
+        RoleScopedChatInput {
+            role,
+            content: content.to_owned(),
+            name: None,
+        }
+    }
+
+    /// Builds a message with role [`Role::User`].
+    pub fn user(content: &str) -> RoleScopedChatInput {
+        message(Role::User, content)
+    }
+
+    /// Builds a message with role [`Role::Assistant`].
+    pub fn assistant(content: &str) -> RoleScopedChatInput {
+        message(Role::Assistant, content)
+    }
+
+    /// Builds a message with role [`Role::System`].
+    pub fn system(content: &str) -> RoleScopedChatInput {
+        message(Role::System, content)
+    }
+
+    /// Builds a message with role [`Role::Tool`].
+    pub fn tool(content: &str) -> RoleScopedChatInput {
+        message(Role::Tool, content)
+    }
+}
diff --git a/worker/src/ai/models/llama_4_scout_17b_16e_instruct.rs b/worker/src/ai/models/llama_4_scout_17b_16e_instruct.rs
new file mode 100644
index 000000000..6c5747571
--- /dev/null
+++ b/worker/src/ai/models/llama_4_scout_17b_16e_instruct.rs
@@ -0,0 +1,17 @@
+use worker_sys::{AiTextGenerationInput, AiTextGenerationOutput};
+
+use crate::{Model, StreamableModel};
+
+/// Marker type for the `@cf/meta/llama-4-scout-17b-16e-instruct` model.
+#[derive(Debug)]
+pub struct Llama4Scout17b16eInstruct;
+
+impl Model for Llama4Scout17b16eInstruct {
+    const MODEL_NAME: &'static str = "@cf/meta/llama-4-scout-17b-16e-instruct";
+
+    type Input = AiTextGenerationInput;
+
+    type Output = AiTextGenerationOutput;
+}
+
+impl StreamableModel for Llama4Scout17b16eInstruct {}