200 changes: 124 additions & 76 deletions contextwindow.go
@@ -4,27 +4,27 @@
//
// # Abbreviated usage
//
//	model, err := NewOpenAIResponsesModel(shared.ResponsesModel4o)
//	if err != nil {
//		log.Fatalf("Failed to create model: %v", err)
//	}
//
//	cw, err := contextwindow.New(model, nil, "")
//	if err != nil {
//		log.Fatalf("Failed to create context window: %v", err)
//	}
//	defer cw.Close()
//
//	if err := cw.AddPrompt(ctx, "how's the weather?"); err != nil {
//		log.Fatalf("Failed to add prompt: %v", err)
//	}
//
//	response, err := cw.CallModel(ctx)
//	if err != nil {
//		log.Fatalf("Failed to call model: %v", err)
//	}
//
//	fmt.Printf("response: %s\n", response)
//
// # System prompts
//
@@ -35,19 +35,19 @@
//
// Instruct LLMs to call tools locally with [ContextWindow.AddTool] (and [NewTool]).
//
//	lsTool := contextwindow.NewTool("list_files", `
//	This tool lists files in the specified directory.
//	`).AddStringParameter("directory", "Directory to list", true)
//
//	cw.AddTool(lsTool, contextwindow.ToolRunnerFunc(func(ctx context.Context,
//		args json.RawMessage) (string, error) {
//		var treq struct {
//			Dir string `json:"directory"`
//		}
//		json.Unmarshal(args, &treq)
//		// actually run ls, or pretend to
//		return "here\nare\nsome\nfiles.exe\n", nil
//	}))
//
// You can selectively enable and disable tools with [ContextWindow.CallModelWithOpts].
//
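//	// A minimal sketch: the EnabledTools field name is an assumption
//	// made for illustration; check CallModelOpts for the real options.
//	response, err := cw.CallModelWithOpts(ctx, contextwindow.CallModelOpts{
//		EnabledTools: []string{"list_files"},
//	})
//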
@@ -62,12 +62,12 @@
//
// You can provide a summarizer model to automatically compact your context window:
//
//	summarizerModel, err := openai.New(apiKey, "gpt-3.5-turbo")
//	if err != nil {
//		log.Fatalf("Failed to create summarizer: %v", err)
//	}
//
//	cw, err := contextwindow.New(model, summarizerModel, "")
//
// And then "compress" your context with [ContextWindow.SummarizeLiveContent].
//
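//	// A minimal sketch, assuming SummarizeLiveContent takes a context
//	// and returns an error:
//	if err := cw.SummarizeLiveContent(ctx); err != nil {
//		log.Fatalf("Failed to summarize: %v", err)
//	}
//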
@@ -81,19 +81,33 @@
// LLM conversations are stored in SQLite. If you don't care about persistent
// storage for your context, just specify ":memory:" as your database path.
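//
//	// For example, an ephemeral in-memory context window, using the same
//	// New signature shown above:
//	cw, err := contextwindow.New(model, nil, ":memory:")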
//
// # Threading and Fallback Behavior
//
// When server-side threading is enabled, the library attempts to use
// response_id-based threading for efficiency. However, several conditions
// can cause automatic fallback to client-side threading:
//
//   - The response_id chain is broken or invalid
//   - Tool calls are present (they break server-side threading)
//   - The model's threading API call fails
//   - The context has no previous response_id (first call)
//
// Fallback is automatic and transparent: conversations continue normally
// using the full message history. Check logs for threading decisions.
//
// # Thread Safety
//
// ContextWindow write operations (AddPrompt, SwitchContext, SetMaxTokens, etc.)
// require external coordination when used concurrently. However, you can use
// ContextWindow.Reader() to get a thread-safe read-only view:
//
//	reader := cw.Reader()
//	go updateUI(reader)      // safe for concurrent use
//	go updateMetrics(reader) // safe for concurrent use
//
//	// Meanwhile, main thread can safely modify state:
//	cw.SwitchContext("new-context")
//	cw.SetMaxTokens(8192)
//
// ContextReader provides access to read operations like LiveRecords(), TokenUsage(),
// and context querying, all of which are safe for concurrent use.
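//
//	// A read-side sketch; LiveRecords returning ([]Record, error) and
//	// TokenUsage returning an int are assumptions, so check ContextReader
//	// for the exact signatures:
//	recs, err := reader.LiveRecords()
//	if err != nil {
//		log.Fatalf("Failed to list records: %v", err)
//	}
//	fmt.Printf("%d live records, %d tokens used\n", len(recs), reader.TokenUsage())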
@@ -310,12 +324,13 @@ func (cw *ContextWindow) AddToolOutput(output string) error {

// SetRecordLiveStateByRange updates the live status of records in the specified range.
// Indices are based on the current LiveRecords() slice, with both start and end inclusive.
// This allows selective marking of context elements as active (live=true) or
// inactive (live=false) based on their position in the conversation.
//
// Examples:
//
//	SetRecordLiveStateByRange(2, 4, false) // marks records at indices 2, 3, 4 as dead
//	SetRecordLiveStateByRange(5, 5, false) // marks only record at index 5 as dead
func (cw *ContextWindow) SetRecordLiveStateByRange(startIndex, endIndex int, live bool) error {
if startIndex < 0 || endIndex < startIndex {
return fmt.Errorf("invalid range: startIndex=%d, endIndex=%d", startIndex, endIndex)
@@ -417,6 +432,32 @@ func (cw *ContextWindow) CallModel(ctx context.Context) (string, error) {
return cw.CallModelWithOpts(ctx, CallModelOpts{})
}

func (cw *ContextWindow) shouldAttemptServerSideThreading(
ci Context,
recs []Record,
) (should bool, reason string /* not using this yet but seems like a good idea */) {

if !ci.UseServerSideThreading {
return false, "server-side threading not enabled for context"
}

_, ok := cw.model.(ServerSideThreadingCapable)
if !ok {
return false, "model does not support server-side threading"
}

if ci.LastResponseID == nil || *ci.LastResponseID == "" {
return false, "no last_response_id available (first call or chain broken)"
}

valid, reason := ValidateResponseIDChain(cw.db, ci)
if !valid {
return false, fmt.Sprintf("response_id chain invalid: %s", reason)
}

return true, "preconditions met"
}

// CallModelWithOpts drives an LLM with options. It composes live messages, invokes cw.model.Call,
// logs the response, updates token count, and triggers compaction.
func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOpts) (string, error) {
@@ -436,48 +477,56 @@ func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOpts) (string, error) {
return "", fmt.Errorf("list live records: %w", err)
}

	var (
		events     []Record
		tokensUsed int
		responseID *string
	)

	// Server-side threading (`previous_response_id`) sends only the most recent
	// prompt and a backlink to the last response, rather than sending the entire
	// thread on every LLM call.
	attemptServerSide, _ := cw.shouldAttemptServerSideThreading(contextInfo, recs)

	if attemptServerSide {
		threadingModel := cw.model.(ServerSideThreadingCapable)
		var err error

		if optsModel, ok := threadingModel.(CallOptsCapable); ok {
			events, responseID, tokensUsed, err = optsModel.CallWithThreadingAndOpts(
				ctx,
				true,
				contextInfo.LastResponseID,
				recs,
				opts,
			)
		} else {
			events, responseID, tokensUsed, err = threadingModel.CallWithThreading(
				ctx,
				true,
				contextInfo.LastResponseID,
				recs,
			)
		}

		if err != nil {
			// Fall through to client-side threading
			attemptServerSide = false
		}
	}

	// Use client-side threading (either as fallback or default)
	if !attemptServerSide {
		if optsModel, ok := cw.model.(CallOptsCapable); ok {
			events, tokensUsed, err = optsModel.CallWithOpts(ctx, recs, opts)
		} else {
			events, tokensUsed, err = cw.model.Call(ctx, recs)
		}
		if err != nil {
			if contextInfo.UseServerSideThreading {
				return "", fmt.Errorf("call model (fallback to client-side threading): %w", err)
			}
			return "", fmt.Errorf("call model: %w", err)
		}
		// Client-side threading doesn't return responseID
		responseID = nil
	}

cw.metrics.Add(tokensUsed)
@@ -497,7 +546,6 @@ func (cw *ContextWindow) CallModelWithOpts(ctx context.Context, opts CallModelOpts) (string, error) {
lastMsg = event.Content
}

if responseID != nil {
err = UpdateContextLastResponseID(cw.db, contextID, *responseID)
if err != nil {