-
Notifications
You must be signed in to change notification settings - Fork 65
feat: Lazy Spans and KV Blocks #249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Lazy Spans and KV Blocks #249
Conversation
And the way it fits into a model that uses apply_chat_tempalte or any other parser/renderer. Note that there's still a bug entailed by the chance that there are also substrings which "hit" on the cached contents. We don't anticipate this happens often in practice because of how KV cache smashing should typically be used, but it's something we need to address by introducing the use of sentinel values, or indexing string machines, or something else along those lines. no-verify commit because the point of this code is documentation.
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🟢 Enforce conventional commitWonderful, this rule succeeded.Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/
|
mellea/backends/_utils.py
Outdated
| def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]: | ||
| """Returns the generation walk ordering for a Span.""" | ||
| match c: | ||
| case ModelOutputThunk() if not c.is_computed(): | ||
| return [c] | ||
| case CBlock(): | ||
| return [] | ||
| case Component(): | ||
| parts_walk = [generate_walk(p) for p in c.parts()] | ||
| return itertools.chain.from_iterable(parts_walk) # aka flatten |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI @jakelorocco
- we'll have to start doing this in the backend generate calls.
- This also means that we need to go back through stdlib and use
parts()correctly. (No action on your part atm) - [-] We probably want some sort of linting rule for third party code that warns the developer when they've got data in a
Componentclass which has typeCBlock | Componentbut which does not appear inparts(). - [-] I think we might want to make
ModelOutputThunkNOT be a subtype ofCBlockbecause Python pattern matching is first-match not most-specific-match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nrfulton, should we also add some sort of computed / non-computed flag to Components because they will now suffer a similar situation as ModelOutputThunks?
And is it up to the Component owner what happens when not all parts of a Component are computed? For example, with a ModelOutputThunk, it's value is None until it is fully computed. Should we specify a similar default behavior for components?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we might want to make ModelOutputThunk NOT be a subtype of CBlock because Python pattern matching is first-match not most-specific-match.
I think that's fine. It's yet to be seen / fully implemented, but in the work for adding return types and parsing functions to Components, a CBlock is really just a Component with no parts (or one part?) that has a str return type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we also add some sort of computed / non-computed flag to Components because they will now suffer a similar situation as ModelOutputThunks?
I need to think about this. It's not quite the same as ModelOutputThunks. And I think it can be a computed method rather than a flag.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we specify a similar default behavior for components?
We need to think about this. It's different from what happens with mots.
Things can go wrong. In particular: Component.format_for_llm should only be called when component prefillable judgement is derivable. But to your question regarding "similar behavior": format_for_llm can't ensure this contract holds itself because it doesn't have a backend in context (and shouldn't!).
NB: the problem isn't introduced by this PR, it already exists, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it already exists but doesn't manifest since we pretty much always use computed stuff right now.
docs/examples/melp/states.py
Outdated
| import mellea | ||
| from mellea.stdlib.base import CBlock, Context, SimpleContext | ||
| from mellea.stdlib.span import Span, SimpleComponent | ||
| from mellea.backends import Backend | ||
| from mellea.backends.ollama import OllamaModelBackend | ||
| import asyncio | ||
|
|
||
|
|
||
| async def main(backend: Backend, ctx: Context): | ||
| a_states = "Alaska,Arizona,Arkansas".split(",") | ||
| m_states = "Missouri", "Minnesota", "Montana", "Massachusetts" | ||
|
|
||
| a_state_pops = dict() | ||
| for state in a_states: | ||
| a_state_pops[state], _ = await backend.generate_from_context( | ||
| CBlock(f"What is the population of {state}? Respond with an integer only."), | ||
| SimpleContext(), | ||
| ) | ||
| a_total_pop = SimpleComponent( | ||
| instruction=CBlock( | ||
| "What is the total population of these states? Respond with an integer only." | ||
| ), | ||
| **a_state_pops, | ||
| ) | ||
| a_state_total, _ = await backend.generate_from_context(a_total_pop, SimpleContext()) | ||
|
|
||
| m_state_pops = dict() | ||
| for state in m_states: | ||
| m_state_pops[state], _ = await backend.generate_from_context( | ||
| CBlock(f"What is the population of {state}? Respond with an integer only."), | ||
| SimpleContext(), | ||
| ) | ||
| m_total_pop = SimpleComponent( | ||
| instruction=CBlock( | ||
| "What is the total population of these states? Respond with an integer only." | ||
| ), | ||
| **m_state_pops, | ||
| ) | ||
| m_state_total, _ = await backend.generate_from_context(m_total_pop, SimpleContext()) | ||
|
|
||
| print(await a_state_total.avalue()) | ||
| print(await m_state_total.avalue()) | ||
|
|
||
|
|
||
| backend = OllamaModelBackend(model_id="granite4:latest") | ||
| asyncio.run(main(backend, SimpleContext())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI @HendrikStrobelt this is what lazy spans look like now.
Remember that await backend.generate_from_context doesn't actually await on the computation of the result. This merely awaits on the triggering on the generate call. So the full lifecycle of an call that looks sync has two awaits:
mot, new_ctx = await backend.generate_from_context(...)
result: str = await mot.avalue()It's not the prettiest code in the world, but it's nice to see that lazy spans still work after our long sojourn into devexp land.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remember that await backend.generate_from_context doesn't actually await on the computation of the result. This merely awaits on the triggering on the generate call.
Just wanted to call this out since python async is weird. Since backend.generate_from_context() can always do work immediately (ie processing the model opts / context, queueing up the API call, ...), Python should never actually pause the control flow at that await boundary. It will always immediately do the work to get you the ModelOutputThunk since none of the backends (currently) have await statements inside their backend.generate_from_context() functions that actually have to await asynchronous work being done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another gotcha: we should await .gather() rather than await if you are awaiting on multiple things. There's a bug in my version of the generate_walk:
_to_compute = generate_walk(action)
await asyncio.gather([x.avalue() for x in _to_compute])|
Related stuff coming out of today's standup:
|
|
Backend cleanup debt captured in #253 |
TODO-nrf: we need to add generate walks to every generation call.
Deletes the stdlib.span package and moves simplecomponent into base. Fixes a big in call to gather (should be *list not list)
Accepted the `nathan/conceptual_spans` side of this merge for huggingface.py. I'm now going to re-add that code in the next commit.
|
From a conversation with @jakelorocco : It is dangerous for us to have both We should make it easy for Component developers to say how a component should be represented to an LLM without understanding the core, but also in a way that does not violate invariants about A major "gotcha" is that Component developers might rip out strings from CBlocks or ModelOutputThunks within their format_for_llm call. This would violate our We can guard against this in two ways:
Then, the The |
Mots are fully computed when generation is done (even if they have an unresolved tool request). It's up to the user or sampling strategy to decide whether to actually call that tool. We can flag that better. But the result of the tool call shouldn't necessarily be a part of the mot. It's its own object that must be passed back to the model as a separate message if desired. |
jakelorocco
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm; ran all tests
* Adds cache smash code from the Project M codebase. * rename to avoid clash b/w cache/ and cache.py * Adds cache flag to CBlock. * Initial work on re-introducing span-ish KV caching. no-verify. * Adds a crystallization of the kv smash code And the way it fits into a model that uses apply_chat_tempalte or any other parser/renderer. Note that there's still a bug entailed by the chance that there are also substrings which "hit" on the cached contents. We don't anticipate this happens often in practice because of how KV cache smashing should typically be used, but it's something we need to address by introducing the use of sentinel values, or indexing string machines, or something else along those lines. no-verify commit because the point of this code is documentation. * Adds KV cache smash. * Adds example of kv cache smash. * Adds a SimpleComponent. * Adds a simple lazy example. * ollama generate walk. TODO-nrf: we need to add generate walks to every generation call. * Does gather() instead of awaiting on each thunk separately. * Refactor and bug fixes. Deletes the stdlib.span package and moves simplecomponent into base. Fixes a big in call to gather (should be *list not list) * backend walks. * Adds heapcomponents. * Make uncomputed mots logging less noisy. * adds a simple example. * Cleans up fib example. * Adds parts() for instruction and genslot components. * Don't call things components which are not components. * ruff. * Starts adding some examples for a deepdive on sessions. * blah * blah * Add parts() to chat. * Fixes GenerativeSlot.parts() * Confirm assumption that RichDocument has no parts() for now. * Define parts() on TableQuery * Fixes ruff errors. * Fixes error in HeapContext.add caught by mypy. * Fixes mypy errors caused by shadowing * Adds parts() definitions to the rest of the RichDocument components. These need a substantial cleanup and refactor with greater attention to detail. * fixes Instruction.parts() * Improves warning message for Intrinsic.parts() * update comment on mify.parts() * parts() implementations for MObject components. * parts() implementation for Requirements. * Some notes about the deep dives. * Fixes line noise in previous commit. * Finish resolving merge. * Examples are working (for some value of working -- results are garbage. * precommit hooks are passing. * Small changes to hf kv smash example. * Fix fib example. * Remove accidental commit. * Removes unnecessary print statements. * Removes HeapContext. * Intrinsics cannot surface parts because they always rewrite history anyways. * removes dead helper code. * removed code clone. * adds test. * Adds type:ignore because mypy 1.19.1 is buggy. * fixes bug in GenerativeSlot.parts() * adds missing arg in span tests. * See generative-computing#258 * fixes failing tests. --------- Co-authored-by: Avinash Balakrishnan <[email protected]>
Introduction to Mellea's Spans
A Span is a contiguous piece of text that defines a tokenization and KV cache boundary within a context.
Spans play two roles.
Spans delineate conceptually/semantically related content. Examples of Spans in this sense include: RAG documents, chunks of RAG docuemnts, encoded images, other artifacts (such as code, execution traces, error logs), and chat messages.
Most spans also define sensible KV boundaries modulo positional encodings. For example: we can pre-compute KV cache for all of the documents in a RAG database and then re-use those as prefixes. So we have the KV blocks associated with each document, and each of those KV blocks corresponds to a conceptually whole entity (the document). This is why we include the words "tokenization boundary" in the definition of the term Span.
It is useful to distinguish between the two roles that a Span plays when discussing implementation details. We refer to all the KV caching semantics as "KV blocks" and we refer the conceptual grouping semantics as "conceptual spans".
When we say "span" nakedly, we mean something that is both a "conceptual span" and also a "kv block"; i.e., a span is is a contiguous piece of text that defines a tokenization and KV cache boundary within a context
This PR is about both. It started with conceptual spans and focused on re-introducing these "conceptual spans" into mellea from one of our earlier experimental code bases. We then merged in the corresponding PR, now closed, on the KV span / KV block aspect (#111). The two PRs are now merged together.
Lazy Span Implementation Details
The Mellea tutorial uses the stdlib
MelleaSessionandmfuncabstractions to hide Mellea's core from the user. In this section we peel back the Session and mfunc abstracions so that we can see how Mellea works under the hood.Mellea represents data using three types:
Component | CBlock | ModelOutputThunk.CBlocksare a wrapper around inputs to an LLM.ModelOutputThunks are outputs from LLMs. These are created prior to any LLM call actually happening.Components are composite types that implement a protocol that explains how the Component should be represented to an LLM.Let's review each of these.
CBlocks and Thunks
CBlocks(andComponents) are passed into a model via aBackend. TheBackendemits aModelOutputThunk(with a newContextwhich we will talk about in a moent). For example,Notice how a ModelOutputThunk can be uncomputed (
mot.value is None) or computedmot.avalue is not None.Important
We need to think about intermediate MoT states, such as where a mot has been cmoputed but has a tool call that is pending.
Components
Components can be composed of both
CBlocks andModelOutputThunks. For example,Let's extend this component a bit so that we can print it out and see which of its thunks are computed:
(Aside: Recall in the first example we had to
awaitthe value ofout_0before computingnext_int.One of the things we need to change is automatic awaiting on MoTs that are contituents of Components as part of the generate call. This existed in our first couple codebases and we need to add thatb ack here.)
Notably,
Components can be constructed usingModelOutputThunks that are not yet computed. So, in our core data structure we have a data dependency graph. E.g.,KV Blocks
Each
CBlock | Componentadditionally corresponds to a tokenization boundary and associated KV cache. These KV caches can be mashed together to allow for cache reuse beyond prefix-based reuse. This is currently implemented in the huggingface backend.Remaining Todos
kv block stuff
Won't-do-for-now
stdlib fixes
Define
parts()for existing components:SamplingResult