-
Notifications
You must be signed in to change notification settings - Fork 354
feat: create trait definitions for model and streamable model #833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
8b9022f
d61679d
3607ee4
ac6f9fa
0a1989f
442eff7
5654e9a
baecae1
4080689
85d2b86
1ce8ff3
5b5310f
af555fb
409b221
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| use serde::{Deserialize, Serialize}; | ||
| use worker::{Env, Model, Request, Response, Result, StreamableModel}; | ||
|
|
||
| use crate::SomeSharedData; | ||
|
|
||
/// Marker type for the `@cf/meta/llama-4-scout-17b-16e-instruct` model.
/// Carries no data; it exists only to anchor the `Model`/`StreamableModel`
/// trait implementations below.
pub struct Llama4Scout17b16eInstruct;
|
|
||
/// Request payload for a plain text-generation call.
#[derive(Serialize)]
pub struct DefaultTextGenerationInput {
    // The prompt text sent to the model.
    pub prompt: String,
}
|
|
||
/// Response payload returned by a text-generation call.
#[derive(Deserialize)]
pub struct DefaultTextGenerationOutput {
    // The generated text returned by the model.
    pub response: String,
}
|
|
||
| impl From<DefaultTextGenerationOutput> for Vec<u8> { | ||
| fn from(value: DefaultTextGenerationOutput) -> Self { | ||
| value.response.into_bytes() | ||
| } | ||
| } | ||
|
|
||
/// Binds the marker type to its catalog name and default text-generation
/// input/output payloads so it can be used with `Ai::run`.
impl Model for Llama4Scout17b16eInstruct {
    const MODEL_NAME: &str = "@cf/meta/llama-4-scout-17b-16e-instruct";

    type Input = DefaultTextGenerationInput;

    type Output = DefaultTextGenerationOutput;
}
|
|
||
/// Opts this model into `Ai::run_streaming`; streaming reuses the same
/// input/output types declared on the `Model` impl.
impl StreamableModel for Llama4Scout17b16eInstruct {}
|
|
||
// Name of the Workers AI binding configured for the test worker.
const AI_TEST: &str = "AI_TEST";
|
|
||
| #[worker::send] | ||
| pub async fn simple_ai_text_generation( | ||
| _: Request, | ||
| env: Env, | ||
| _data: SomeSharedData, | ||
| ) -> Result<Response> { | ||
| let ai = env | ||
| .ai(AI_TEST)? | ||
| .run::<Llama4Scout17b16eInstruct>(DefaultTextGenerationInput { | ||
| prompt: "What is the answer to life the universe and everything?".to_owned(), | ||
| }) | ||
guybedford marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| .await?; | ||
| Response::ok(ai.response) | ||
| } | ||
|
|
||
| #[worker::send] | ||
| pub async fn streaming_ai_text_generation( | ||
| _: Request, | ||
| env: Env, | ||
| _data: SomeSharedData, | ||
| ) -> Result<Response> { | ||
| let stream = env | ||
| .ai(AI_TEST)? | ||
| .run_streaming::<Llama4Scout17b16eInstruct>(DefaultTextGenerationInput { | ||
| prompt: "What is the answer to life the universe and everything?".to_owned(), | ||
| }) | ||
| .await?; | ||
|
|
||
| Response::from_stream(stream) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| import { describe, expect, test } from "vitest"; | ||
| import { mf, mfUrl } from "./mf"; | ||
|
|
||
| async function runTest() { | ||
| let normal_response = await mf.dispatchFetch(`${mfUrl}/ai`); | ||
| expect(normal_response.status).toBe(200); | ||
|
|
||
| let streaming_response = await mf.dispatchFetch(`${mfUrl}/ai/streaming`); | ||
| expect(streaming_response.status).toBe(200); | ||
| } | ||
| describe("ai", runTest); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,17 @@ | ||
| use std::marker::PhantomData; | ||
| use std::pin::Pin; | ||
| use std::task::{self, Poll}; | ||
|
|
||
| use crate::{env::EnvBinding, send::SendFuture}; | ||
| use crate::{Error, Result}; | ||
| use serde::de::DeserializeOwned; | ||
| use serde::Serialize; | ||
| use futures_util::io::{BufReader, Lines}; | ||
| use futures_util::{ready, AsyncBufReadExt as _, Stream, StreamExt as _}; | ||
| use js_sys::Reflect; | ||
| use pin_project::pin_project; | ||
| use serde::{de::DeserializeOwned, Serialize}; | ||
| use wasm_bindgen::{JsCast, JsValue}; | ||
| use wasm_bindgen_futures::JsFuture; | ||
| use wasm_streams::readable::IntoAsyncRead; | ||
| use worker_sys::Ai as AiSys; | ||
|
|
||
| /// Enables access to Workers AI functionality. | ||
|
|
@@ -14,20 +22,27 @@ impl Ai { | |
| /// Execute a Workers AI operation using the specified model. | ||
| /// Various forms of the input are documented in the Workers | ||
| /// AI documentation. | ||
| pub async fn run<T: Serialize, U: DeserializeOwned>( | ||
| &self, | ||
| model: impl AsRef<str>, | ||
| input: T, | ||
| ) -> Result<U> { | ||
| pub async fn run<M: Model>(&self, input: M::Input) -> Result<M::Output> { | ||
| let fut = SendFuture::new(JsFuture::from( | ||
| self.0 | ||
| .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?), | ||
| .run(M::MODEL_NAME, serde_wasm_bindgen::to_value(&input)?), | ||
| )); | ||
| match fut.await { | ||
| Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?), | ||
| Err(err) => Err(Error::from(err)), | ||
| } | ||
| } | ||
|
|
||
| pub async fn run_streaming<M: StreamableModel>(&self, input: M::Input) -> Result<AiStream<M>> { | ||
| let input = serde_wasm_bindgen::to_value(&input)?; | ||
| Reflect::set(&input, &JsValue::from_str("stream"), &JsValue::TRUE)?; | ||
|
|
||
| let fut = SendFuture::new(JsFuture::from(self.0.run(M::MODEL_NAME, input))); | ||
| let raw_stream = fut.await?.dyn_into::<web_sys::ReadableStream>()?; | ||
| let stream = wasm_streams::ReadableStream::from_raw(raw_stream).into_async_read(); | ||
|
|
||
| Ok(AiStream::new(stream)) | ||
| } | ||
| } | ||
|
|
||
// SAFETY: presumably sound because Workers execute user code on a single
// thread, so the inner JS handle is never accessed from two threads at
// once — NOTE(review): confirm this invariant before relying on it.
unsafe impl Sync for Ai {}
|
|
@@ -82,3 +97,75 @@ impl EnvBinding for Ai { | |
| } | ||
| } | ||
| } | ||
|
|
||
| pub trait Model: 'static { | ||
| const MODEL_NAME: &str; | ||
| type Input: Serialize; | ||
| type Output: DeserializeOwned; | ||
| } | ||
|
|
||
/// Marker trait for [`Model`]s that support streamed responses via
/// [`Ai::run_streaming`]; each streamed chunk deserializes to
/// `Model::Output`.
pub trait StreamableModel: Model {}
|
|
||
/// Typed stream of incremental outputs from a streaming Workers AI call.
///
/// Wraps the binding's readable stream in a line reader and yields one
/// deserialized `T::Output` per server-sent-event `data:` line.
#[derive(Debug)]
#[pin_project]
pub struct AiStream<T: StreamableModel> {
    // Line-buffered view over the underlying JS readable stream.
    #[pin]
    inner: Lines<BufReader<IntoAsyncRead<'static>>>,
    // Ties the stream to its model type without storing a value of it.
    phantom: PhantomData<T>,
}
|
|
||
| impl<T: StreamableModel> AiStream<T> { | ||
| pub fn new(stream: IntoAsyncRead<'static>) -> Self { | ||
| Self { | ||
| inner: BufReader::new(stream).lines(), | ||
| phantom: PhantomData, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl<T: StreamableModel> Stream for AiStream<T> { | ||
| type Item = Result<T::Output>; | ||
|
|
||
| fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { | ||
| let mut this = self.project(); | ||
| let string = match ready!(this.inner.poll_next_unpin(cx)) { | ||
| Some(item) => match item { | ||
| Ok(item) => { | ||
| if item.is_empty() { | ||
| match ready!(this.inner.poll_next_unpin(cx)) { | ||
| Some(item) => match item { | ||
| Ok(item) => item, | ||
| Err(err) => { | ||
| return Poll::Ready(Some(Err(err.into()))); | ||
| } | ||
| }, | ||
| None => { | ||
| return Poll::Ready(None); | ||
| } | ||
| } | ||
| } else { | ||
| item | ||
| } | ||
| } | ||
| Err(err) => { | ||
| return Poll::Ready(Some(Err(err.into()))); | ||
| } | ||
| }, | ||
| None => { | ||
| return Poll::Ready(None); | ||
| } | ||
| }; | ||
|
|
||
| let string = if let Some(string) = string.strip_prefix("data: ") { | ||
| string | ||
| } else { | ||
| string.as_str() | ||
| }; | ||
|
|
||
| if string == "[DONE]" { | ||
| return Poll::Ready(None); | ||
| } | ||
|
|
||
| Poll::Ready(Some(Ok(serde_json::from_str(string)?))) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed this seems like the right approach!