diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml
index 47a90a7..cf8f30d 100644
--- a/.github/workflows/integration_test.yml
+++ b/.github/workflows/integration_test.yml
@@ -123,12 +123,21 @@ jobs:
pipx install poetry
poetry install
+ # Restores llm2-persistent_storage/ from the Actions cache before the backend
+ # is started; the "Install and init backend" step below points
+ # APP_PERSISTENT_STORAGE at this directory.
+ # The key is built from the runner OS, this step's cache-name env value and
+ # the hash of llm2/lib/main.py. The step id is referenced later by the save
+ # step, which reuses this step's cache-primary-key output.
+ - name: Cache llm2 models
+ uses: actions/cache/restore@v5
+ id: cache-llm2-models-restore
+ env:
+ cache-name: cache-llm2-models
+ with:
+ path: llm2-persistent_storage/
+ key: ${{ runner.os }}-llm2-models-${{ env.cache-name }}-${{ hashFiles('llm2/lib/main.py') }}
+
- name: Install and init backend
working-directory: ${{ env.APP_NAME }}/lib
env:
APP_VERSION: ${{ fromJson(steps.appinfo.outputs.result).version }}
run: |
- poetry run python3 main.py > ../backend_logs 2>&1 &
+ APP_PERSISTENT_STORAGE="$(pwd)/../../llm2-persistent_storage/" poetry run python3 main.py > ../backend_logs 2>&1 &
- name: Register backend
run: |
@@ -156,6 +165,43 @@ jobs:
curl -u "$CREDS" -H "oCS-APIRequest: true" http://localhost:8080/ocs/v2.php/taskprocessing/task/$TASK_ID?format=json
[ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ]
+      # Saves llm2-persistent_storage/ under the key computed by the restore step
+      # (id: cache-llm2-models-restore).
+      # Skipped when that step reported an exact hit: the key already holds this
+      # content, so archiving the model directory again would only cost time.
+      - name: Cache llm2 models
+        if: steps.cache-llm2-models-restore.outputs.cache-hit != 'true'
+        uses: actions/cache/save@v5
+        with:
+          path: llm2-persistent_storage/
+          key: ${{ steps.cache-llm2-models-restore.outputs.cache-primary-key }}
+
+      # Schedules a core:text2text task with preferStreaming set, polls it until it
+      # reaches a final status or the retry budget is used up, and then requires
+      # both a successful result and at least one poll that saw non-empty output
+      # while the task was still unfinished.
+      # All expansions are quoted so that glob characters (the "?" in the URLs,
+      # anything inside the returned JSON) and whitespace are passed on literally.
+      - name: Run streaming task
+        if: matrix.server-versions == 'master'
+        env:
+          CREDS: "admin:password"
+        run: |
+          set -x
+          TASK=$(curl -X POST -u "$CREDS" -H "oCS-APIRequest: true" -H "Content-type: application/json" "http://localhost:8080/ocs/v2.php/taskprocessing/schedule?format=json" --data-raw '{"input": {"input": "Count from 1 to 20 in words"},"type":"core:text2text", "appId": "test", "customId": "", "preferStreaming": true}')
+          echo "$TASK"
+          TASK_ID=$(echo "$TASK" | jq '.ocs.data.task.id')
+          NEXT_WAIT_TIME=0
+          TASK_STATUS='"STATUS_SCHEDULED"'
+          STREAMING_UPDATES=0
+          until [ "$NEXT_WAIT_TIME" -eq 35 ] || [ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ] || [ "$TASK_STATUS" == '"STATUS_FAILED"' ]; do
+            TASK=$(curl -u "$CREDS" -H "oCS-APIRequest: true" "http://localhost:8080/ocs/v2.php/taskprocessing/task/$TASK_ID?format=json")
+            echo "$TASK"
+            TASK_STATUS=$(echo "$TASK" | jq '.ocs.data.task.status')
+            echo "$TASK_STATUS"
+            TASK_OUTPUT=$(echo "$TASK" | jq -r '.ocs.data.task.output.output // ""')
+            # Output that is visible before a final status counts as a streaming update.
+            if [ -n "$TASK_OUTPUT" ] && [ "$TASK_STATUS" != '"STATUS_SUCCESSFUL"' ] && [ "$TASK_STATUS" != '"STATUS_FAILED"' ]; then
+              STREAMING_UPDATES=$((STREAMING_UPDATES+1))
+              echo "Streaming update detected (count: $STREAMING_UPDATES)"
+            fi
+            # The pause grows by one second with every poll.
+            sleep "$(( NEXT_WAIT_TIME++ ))"
+          done
+          echo "Final status: $TASK_STATUS"
+          echo "Total streaming updates detected: $STREAMING_UPDATES"
+          [ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ]
+          [ "$STREAMING_UPDATES" -gt 0 ]
+
- name: Show logs
if: always()
run: |
diff --git a/default_config/config.json b/default_config/config.json
index 9265221..9850a28 100644
--- a/default_config/config.json
+++ b/default_config/config.json
@@ -61,7 +61,7 @@
"Qwen3.5-9B-Q4_K_M": {
"prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n{user_prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n",
"loader_config": {
- "n_ctx": 16384,
+ "n_ctx": 24000,
"max_tokens": 8192,
"stop": ["<|eot_id|>"],
"temperature": 0.7
diff --git a/lib/change_tone.py b/lib/change_tone.py
index 317fb32..cb2e771 100644
--- a/lib/change_tone.py
+++ b/lib/change_tone.py
@@ -10,6 +10,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class ChangeToneProcessor:
runnable: Runnable
@@ -33,10 +35,10 @@ class ChangeToneProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, input_data: dict) -> dict[str, Any]:
+ def __call__(self, input_data: dict, context: StreamContext | None = None) -> dict[str, Any]:
"""Process a single input"""
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format_prompt(text=input_data['input'], tone=input_data['tone']).to_string())
]
- return {'output':self.runnable.invoke(messages).content }
\ No newline at end of file
+ return {'output': run_runnable_with_streaming(self.runnable, messages, context)}
diff --git a/lib/chat.py b/lib/chat.py
index f848b45..d2c2815 100644
--- a/lib/chat.py
+++ b/lib/chat.py
@@ -3,13 +3,12 @@
"""A chat chain
"""
import json
-from typing import Any, Optional
+from typing import Any
-from langchain.callbacks.manager import CallbackManagerForChainRun
-from langchain.chains.base import Chain
-from langchain_community.chat_models import ChatLlamaCpp
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class ChatProcessor:
"""
@@ -24,10 +23,19 @@ def __init__(self, runner: Runnable):
def __call__(
self,
inputs: dict[str, Any],
+ context: StreamContext | None = None,
) -> dict[str, str]:
system_prompt = inputs['system_prompt']
if inputs.get('memories'):
system_prompt += "\n\nYou can remember things from other conversations with the user. If they are relevant, take into account the following memories: \n" + "\n\n".join(inputs['memories']) + "\n\n"
- return {'output': self.runnable.invoke(
- [('human', system_prompt)] + [(message['role'], message['content']) for message in [json.loads(message) for message in inputs['history']]] + [('human', inputs['input'])]
- ).content}
\ No newline at end of file
+ messages = [('human', system_prompt)] + [
+ (message['role'], message['content'])
+ for message in [json.loads(message) for message in inputs['history']]
+ ] + [('human', inputs['input'])]
+ return {
+ 'output': run_runnable_with_streaming(
+ self.runnable,
+ messages,
+ context,
+ )
+ }
diff --git a/lib/chatwithtools.py b/lib/chatwithtools.py
index 7962162..461b501 100644
--- a/lib/chatwithtools.py
+++ b/lib/chatwithtools.py
@@ -3,21 +3,34 @@
"""A chat chain
"""
import json
+import hashlib
import pprint
import re
-from random import randint
from typing import Any
from langchain_community.chat_models import ChatLlamaCpp
-from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
+from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages.ai import AIMessage
+from streaming import StreamContext, run_runnable_with_streaming
+
def generate_tool_call(tool_call: dict):
content = ''
content += json.dumps({"name": tool_call['name'], "arguments": tool_call['args']})
content += ''
return content
+
+def generate_tool_call_id(tool_call: dict) -> str:
+    """Derive a deterministic id for a parsed tool call.
+
+    The id is the first 16 hex characters of the SHA-1 digest of the call's
+    name and arguments, serialized as JSON with sorted keys. Parsing the same
+    call again therefore yields the same id; two calls with the same name and
+    arguments also share an id.
+
+    Args:
+        tool_call: Parsed call. The arguments are read from ``args`` and, if
+            that key is absent, from ``arguments``; a missing name is
+            serialized as ``null``.
+
+    Returns:
+        A 16-character lowercase hexadecimal string.
+    """
+    arguments = tool_call.get("args", tool_call.get("arguments", {}))
+    stable_payload = json.dumps(
+        {
+            "name": tool_call.get("name"),
+            "args": arguments,
+        },
+        sort_keys=True,
+    )
+    # The digest is only an identifier, not a security measure. Saying so keeps
+    # the call usable on Python builds that restrict SHA-1 for security purposes.
+    digest = hashlib.sha1(stable_payload.encode("utf-8"), usedforsecurity=False)
+    return digest.hexdigest()[:16]
+
def try_parse_tool_calls(content: str):
"""Try parse the tool calls."""
tool_calls = []
@@ -40,7 +53,9 @@ def try_parse_tool_calls(content: str):
func['args'] = func['arguments']
del func['arguments']
if not 'id' in func:
- func['id'] = str(randint(1, 10000000000))
+ func['id'] = generate_tool_call_id(func)
+ if 'type' not in func:
+ func['type'] = 'tool_call'
found = True
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
@@ -66,7 +81,9 @@ def try_parse_tool_calls(content: str):
func['args'] = func['arguments']
del func['arguments']
if not 'id' in func:
- func['id'] = str(randint(1, 10000000000))
+ func['id'] = generate_tool_call_id(func)
+ if 'type' not in func:
+ func['type'] = 'tool_call'
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
pass
@@ -79,6 +96,31 @@ def try_parse_tool_calls(content: str):
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
+
+def strip_tool_calls_for_streaming(content: str) -> str:
+    """Return ``content`` without tool-call markup, for display while streaming.
+
+    NOTE(review): the tag literals of this function were missing from the copy
+    of the patch under review (the first pattern read ``.*?`` and the marker
+    lookup was cut off mid-expression). They are restored here on the
+    assumption that tool calls are wrapped in ``<tool_call>...</tool_call>``
+    -- confirm against ``try_parse_tool_calls`` before merging.
+
+    Args:
+        content: The model output accumulated so far.
+
+    Returns:
+        The text with complete tool calls removed, anything from the start of
+        an unfinished tool call onwards cut off, and a trailing ``<|im_end|>``
+        dropped.
+    """
+    # Complete tool calls, in tag form and in fenced form.
+    sanitized = re.sub(r"<tool_call>.*?</tool_call>", "", content, flags=re.DOTALL)
+    sanitized = re.sub(r"```tool_call\n.*?\n```", "", sanitized, flags=re.DOTALL)
+
+    # Whatever opening marker is left belongs to a call that is still being
+    # generated; hide everything from the earliest one onwards.
+    partial_markers = [
+        index
+        for index in (sanitized.find("<tool_call>"), sanitized.find("```tool_call"))
+        if index != -1
+    ]
+    if partial_markers:
+        sanitized = sanitized[:min(partial_markers)]
+
+    return re.sub(r"<\|im_end\|>$", "", sanitized)
+
+
+def build_streaming_payload(content: str) -> dict[str, Any] | None:
+    """Build the intermediate result for the text generated so far.
+
+    Args:
+        content: The model output accumulated so far.
+
+    Returns:
+        ``None`` when there is neither displayable text nor a parsed tool call.
+        Otherwise a dict with ``output`` (the text as returned by
+        ``strip_tool_calls_for_streaming``) and, if ``try_parse_tool_calls``
+        found any, ``tool_calls`` as a JSON-encoded string.
+    """
+    visible_text = strip_tool_calls_for_streaming(content)
+    tool_calls = try_parse_tool_calls(content).get('tool_calls')
+
+    # Nothing to show yet: tell the caller to skip this update.
+    if not (visible_text or tool_calls):
+        return None
+
+    payload: dict[str, Any] = {'output': visible_text}
+    if tool_calls:
+        payload['tool_calls'] = json.dumps(tool_calls)
+    return payload
+
class ChatWithToolsProcessor:
"""
A chat with tools processor that supports batch processing
@@ -89,7 +131,7 @@ class ChatWithToolsProcessor:
def __init__(self, runner: ChatLlamaCpp):
self.model = runner
- def _process_single_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
+ def _process_single_input(self, input_data: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
system_prompt = """
{downstream_system_prompt}
@@ -150,15 +192,20 @@ def _process_single_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
messages.append(HumanMessage(content=''))
pprint.pprint(messages)
- response = self.model.invoke(messages)
+ response_content = run_runnable_with_streaming(
+ self.model,
+ messages,
+ context,
+ stream_payload_transform=build_streaming_payload,
+ suppress_empty_stream_updates=True,
+ )
- #if not response.tool_calls or len(response.tool_calls) == 0:
- response = AIMessage(**try_parse_tool_calls(response.content))
+ response = AIMessage(**try_parse_tool_calls(response_content))
return {
'output': response.content,
'tool_calls': json.dumps(response.tool_calls)
}
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
- return self._process_single_input(inputs)
\ No newline at end of file
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
+ return self._process_single_input(inputs, context)
diff --git a/lib/contextwrite.py b/lib/contextwrite.py
index 8e8ef08..8fe62fc 100644
--- a/lib/contextwrite.py
+++ b/lib/contextwrite.py
@@ -10,6 +10,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class ContextWriteProcessor:
runnable: Runnable
@@ -36,7 +38,7 @@ class ContextWriteProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
@@ -44,5 +46,5 @@ def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
source_input=inputs['source_input']
))
]
- output = self.runnable.invoke(messages)
- return {'output': output.content}
\ No newline at end of file
+ output = run_runnable_with_streaming(self.runnable, messages, context)
+ return {'output': output}
diff --git a/lib/free_prompt.py b/lib/free_prompt.py
index 4f2cdd2..cb5ca25 100644
--- a/lib/free_prompt.py
+++ b/lib/free_prompt.py
@@ -1,10 +1,12 @@
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
-from typing import Any, List
+from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class FreePromptProcessor:
"""
@@ -20,9 +22,10 @@ def __init__(self, runnable: Runnable):
def __call__(
self,
inputs: dict[str, Any],
+ context: StreamContext | None = None,
) -> dict[str, Any]:
- output = self.runnable.invoke([
+ output = run_runnable_with_streaming(self.runnable, [
SystemMessage(self.system_prompt),
HumanMessage(inputs['input'])
- ]).content
+ ], context)
return {'output': output}
\ No newline at end of file
diff --git a/lib/headline.py b/lib/headline.py
index dc8a6d7..ac88832 100644
--- a/lib/headline.py
+++ b/lib/headline.py
@@ -9,6 +9,8 @@
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class HeadlineProcessor:
"""
@@ -33,12 +35,12 @@ class HeadlineProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
text=inputs['input']
))
]
- output = self.runnable.invoke(messages)
- return {'output': output.content}
+ output = run_runnable_with_streaming(self.runnable, messages, context)
+ return {'output': output}
diff --git a/lib/main.py b/lib/main.py
index 65c9279..652b200 100644
--- a/lib/main.py
+++ b/lib/main.py
@@ -13,6 +13,7 @@
from time import perf_counter, sleep, strftime
from niquests import RequestException
+from streaming import StreamContext
from task_processors import generate_task_processors
from fastapi import FastAPI
from nc_py_api import AsyncNextcloudApp, NextcloudApp, NextcloudException
@@ -24,6 +25,35 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
+
+class NextcloudTaskStreamResult:
+    """Reports intermediate output and progress of one task back to Nextcloud."""
+
+    def __init__(self, nc: NextcloudApp, task_id: int, enabled: bool):
+        # nc: client used for all requests and for logging.
+        # task_id: id of the task the updates belong to.
+        # enabled: whether send() should post anything; cleared after a failed post.
+        self.nc = nc
+        self.task_id = task_id
+        self.enabled = enabled
+
+    def send(self, output: dict) -> None:
+        """Post ``output`` as the task's intermediate result.
+
+        Does nothing while ``enabled`` is false. A failed request is logged as
+        a warning and switches ``enabled`` off, so later calls are no-ops.
+        """
+        if self.enabled:
+            try:
+                self.nc.ocs(
+                    "POST",
+                    f"/ocs/v2.php/taskprocessing/tasks_provider/{self.task_id}/stream-result",
+                    json={"output": output},
+                )
+            except (NextcloudException, RequestException, JSONDecodeError) as e:
+                log(self.nc, LogLvl.WARNING, f"Streaming intermediate task output failed for task {self.task_id}: {e}")
+                self.enabled = False
+
+    def set_progress(self, progress: float) -> bool:
+        """Report ``progress`` for the task.
+
+        Returns:
+            ``True`` if the call went through, ``False`` if it raised one of
+            the handled request errors (which is logged as a warning).
+        """
+        try:
+            self.nc.providers.task_processing.set_progress(self.task_id, progress)
+        except (NextcloudException, RequestException, JSONDecodeError) as e:
+            log(self.nc, LogLvl.WARNING, f"Updating progress failed for task {self.task_id}: {e}")
+            return False
+        else:
+            return True
+
def log(nc, level, content):
logger.log((level+1)*10, content)
if level < LogLvl.WARNING:
@@ -136,7 +166,12 @@ def background_thread_task():
log(nc, LogLvl.INFO, "Generating reply")
time_start = perf_counter()
log(nc, LogLvl.INFO, task.get("input"))
- result = task_processor(task.get("input"))
+ stream_result = NextcloudTaskStreamResult(nc, task["id"], bool(task.get("preferStreaming", None)))
+ stream_context = StreamContext(
+ stream_result=stream_result.send if stream_result.enabled else None,
+ progress_callback=stream_result.set_progress,
+ )
+ result = task_processor(task.get("input"), context=stream_context)
log(nc, LogLvl.INFO, f"reply generated: {round(float(perf_counter() - time_start), 2)}s")
log(nc, LogLvl.INFO, result)
nc.providers.task_processing.report_result(
diff --git a/lib/proofread.py b/lib/proofread.py
index 5d40b7b..1d4a715 100644
--- a/lib/proofread.py
+++ b/lib/proofread.py
@@ -9,6 +9,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class ProofreadProcessor:
"""
@@ -33,12 +35,12 @@ class ProofreadProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
text=inputs['input']
))
]
- output = self.runnable.invoke(messages)
- return {'output': output.content}
\ No newline at end of file
+ output = run_runnable_with_streaming(self.runnable, messages, context)
+ return {'output': output}
diff --git a/lib/reformulate.py b/lib/reformulate.py
index 9b2d70d..d3392c0 100644
--- a/lib/reformulate.py
+++ b/lib/reformulate.py
@@ -7,6 +7,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class ReformulateProcessor:
"""
@@ -33,12 +35,12 @@ class ReformulateProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
text=inputs['input']
))
]
- output = self.runnable.invoke(messages)
- return {'output': output.content}
\ No newline at end of file
+ output = run_runnable_with_streaming(self.runnable, messages, context)
+ return {'output': output}
diff --git a/lib/simplify.py b/lib/simplify.py
index 2abdb26..bff3f32 100644
--- a/lib/simplify.py
+++ b/lib/simplify.py
@@ -7,6 +7,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class SimplifyProcessor:
"""
@@ -36,12 +38,12 @@ class SimplifyProcessor:
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format(
text=inputs['input']
))
]
- output = self.runnable.invoke(messages)
- return {'output': output.content}
+ output = run_runnable_with_streaming(self.runnable, messages, context)
+ return {'output': output}
diff --git a/lib/streaming.py b/lib/streaming.py
new file mode 100644
index 0000000..2cddbb5
--- /dev/null
+++ b/lib/streaming.py
@@ -0,0 +1,154 @@
+# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from time import monotonic
+from typing import Any, Callable
+
+
+def extract_text_content(value: Any) -> str:
+    """Reduce a message, chunk or raw content value to plain text.
+
+    Args:
+        value: ``None``, a string, a list of content parts, or an object whose
+            ``content`` attribute holds one of those.
+
+    Returns:
+        ``""`` for ``None``; the string itself for string content; for a list,
+        the concatenation of its string items and of the text found in its
+        dict items (other items are ignored); ``str(content)`` for anything else.
+    """
+    if value is None:
+        return ""
+
+    # Objects carrying a ``content`` attribute are unwrapped; anything else is
+    # treated as the content itself.
+    content = getattr(value, "content", value)
+
+    if isinstance(content, str):
+        return content
+    if not isinstance(content, list):
+        return str(content)
+
+    def part_text(item: Any) -> str:
+        # A plain string part is used as is.
+        if isinstance(item, str):
+            return item
+        if isinstance(item, dict):
+            # A string under "text" wins; otherwise accept a part of type
+            # "text" that keeps its string under "content".
+            text = item.get("text")
+            if isinstance(text, str):
+                return text
+            if item.get("type") == "text" and isinstance(item.get("content"), str):
+                return item["content"]
+        return ""
+
+    return "".join(part_text(item) for item in content)
+
+
+@dataclass
+class StreamContext:
+    """Collects a task's intermediate output and forwards it at a limited rate."""
+
+    # Receives a copy of the current output whenever an update is emitted;
+    # None means streaming is off.
+    stream_result: Callable[[dict[str, Any]], None] | None = None
+    # Receives progress values passed to set_progress(); None means no reporting.
+    progress_callback: Callable[[float], Any] | None = None
+    # Minimum time between two emits that are not forced.
+    stream_interval_seconds: float = 0.75
+    # The output assembled so far.
+    current_output: dict[str, Any] = field(default_factory=dict)
+    # monotonic() value of the last emit; 0.0 until something has been emitted.
+    _last_emit_at: float = field(default=0.0, init=False)
+    # Copy of the payload sent last, used to skip identical updates.
+    _last_emitted_output: dict[str, Any] | None = field(default=None, init=False)
+
+    @property
+    def enabled(self) -> bool:
+        """Whether a ``stream_result`` callback is set."""
+        return self.stream_result is not None
+
+    def update_output(self, output: dict[str, Any] | None = None, *, force: bool = False, **extra: Any) -> None:
+        """Merge ``output`` and then ``extra`` into the current output and emit it."""
+        for changes in (output, extra):
+            if changes:
+                self.current_output.update(changes)
+        self.emit(force=force)
+
+    def update_text(
+        self,
+        text: str,
+        *,
+        key: str = "output",
+        force: bool = False,
+        **extra_output: Any,
+    ) -> None:
+        """Store ``text`` under ``key``, merge ``extra_output`` and emit the result."""
+        self.current_output[key] = text
+        # Merging an empty mapping changes nothing, so no guard is needed.
+        self.current_output.update(extra_output)
+        self.emit(force=force)
+
+    def emit(self, *, force: bool = False) -> None:
+        """Send the current output to ``stream_result`` if an update is due.
+
+        Nothing is sent when there is no callback, when the output is empty,
+        when the previous emit was less than ``stream_interval_seconds`` ago
+        (unless ``force`` is set), or when the output equals the last one sent.
+        """
+        if not self.stream_result or not self.current_output:
+            return
+
+        now = monotonic()
+        too_soon = bool(self._last_emit_at) and now - self._last_emit_at < self.stream_interval_seconds
+        if too_soon and not force:
+            return
+
+        snapshot = dict(self.current_output)
+        if snapshot != self._last_emitted_output:
+            self.stream_result(snapshot)
+            self._last_emit_at = now
+            self._last_emitted_output = snapshot
+
+    def set_progress(self, progress: float) -> Any:
+        """Pass ``progress`` to ``progress_callback`` and return its result.
+
+        Returns ``True`` when no callback is set.
+        """
+        callback = self.progress_callback
+        return True if callback is None else callback(progress)
+
+
+def run_runnable_with_streaming(
+    runnable: Any,
+    messages: list[Any],
+    context: StreamContext | None = None,
+    *,
+    output_key: str = "output",
+    stream_text_transform: Callable[[str], str] | None = None,
+    stream_payload_transform: Callable[[str], dict[str, Any] | None] | None = None,
+    suppress_empty_stream_updates: bool = False,
+    **extra_output: Any,
+) -> str:
+    """Run ``runnable`` on ``messages`` and return the generated text.
+
+    Without an enabled ``context`` the runnable is invoked once. Otherwise its
+    ``stream()`` output is accumulated and the text so far is published to the
+    context after every non-empty chunk, plus once more (forced) at the end.
+
+    Args:
+        runnable: Object offering ``invoke(messages)`` and ``stream(messages)``.
+        messages: Input handed to the runnable unchanged.
+        context: Receiver of intermediate updates; ``None`` or a context
+            without a stream callback selects the non-streaming path.
+        output_key: Key under which text updates are stored in the context.
+        stream_text_transform: Optional mapping applied to the accumulated text
+            before it is published; ignored if ``stream_payload_transform`` is set.
+        stream_payload_transform: Optional builder turning the accumulated text
+            into a whole update dict; returning ``None`` skips the update.
+        suppress_empty_stream_updates: Skip empty updates. The closing text
+            update is published regardless of this flag.
+        **extra_output: Additional entries merged into every published update.
+
+    Returns:
+        The complete, untransformed text produced by the runnable.
+    """
+    if not (context and context.enabled):
+        return extract_text_content(runnable.invoke(messages))
+
+    def publish(text_so_far: str, *, final: bool) -> None:
+        # One update for the text accumulated so far; the closing update is forced.
+        if stream_payload_transform:
+            payload = stream_payload_transform(text_so_far)
+            if payload is None:
+                return
+            if extra_output:
+                payload.update(extra_output)
+            if suppress_empty_stream_updates and not payload:
+                return
+            context.update_output(payload, force=final)
+            return
+
+        visible = stream_text_transform(text_so_far) if stream_text_transform else text_so_far
+        # Only intermediate text updates honour the suppress flag.
+        if suppress_empty_stream_updates and visible == "" and not final:
+            return
+        context.update_text(visible, key=output_key, force=final, **extra_output)
+
+    accumulated = ""
+    for chunk in runnable.stream(messages):
+        piece = extract_text_content(chunk)
+        if piece:
+            accumulated += piece
+            publish(accumulated, final=False)
+
+    publish(accumulated, final=True)
+    return accumulated
+
diff --git a/lib/summarize.py b/lib/summarize.py
index ad64d8e..ebcc86b 100644
--- a/lib/summarize.py
+++ b/lib/summarize.py
@@ -8,6 +8,8 @@
from langchain_core.runnables import Runnable
from langchain.text_splitter import RecursiveCharacterTextSplitter
+from streaming import StreamContext, run_runnable_with_streaming
+
class SummarizeProcessor:
runnable: Runnable
@@ -46,32 +48,50 @@ def __init__(self, runnable: Runnable, n_ctx: int = 8000):
length_function=len,
)
-    def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+    def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
+        """Summarize ``inputs['input']``, chunk by chunk if it was split.
+
+        A text that yields a single split is summarized in one call, streamed
+        through ``context`` when one is given. Otherwise every split is
+        summarized on its own and the partial summaries are merged in a final
+        call; only that final call is streamed.
+
+        Progress is reported in equal steps per model call (one per split plus
+        the merge), so it rises evenly up to 100.
+
+        Returns:
+            A dict with the summary under ``output``.
+        """
         # Split text if needed
         splits = self.text_splitter.split_text(inputs['input'])
+        total_splits = len(splits)
         if len(splits) == 1:
             messages = [
                 SystemMessage(content=self.system_prompt),
                 HumanMessage(content=self.user_prompt.format(input=splits[0]))
             ]
-            output = self.runnable.invoke(messages)
-            return {'output': output.content}
+            output = run_runnable_with_streaming(
+                self.runnable,
+                messages,
+                context,
+            )
+            return {'output': output}
         # Process each split
         summaries = []
-        for split in splits:
+        for index, split in enumerate(splits, start=1):
             messages = [
                 SystemMessage(content=self.system_prompt),
                 HumanMessage(content=self.user_prompt.format(input=split))
             ]
             output = self.runnable.invoke(messages)
             summaries.append(output.content)
+            if context:
+                # total_splits + 1 model calls in all: the splits plus the merge.
+                context.set_progress(index / (total_splits + 1) * 100)
         # Merge summaries
         messages = [
             SystemMessage(content=self.system_prompt),
             HumanMessage(content=self.merge_prompt.format(input="\n\n".join(summaries)))
         ]
-        final_output = self.runnable.invoke(messages)
-        return {'output': final_output.content}
\ No newline at end of file
+        final_output = run_runnable_with_streaming(
+            self.runnable,
+            messages,
+            context,
+        )
+
+        if context:
+            context.set_progress(100)
+
+        return {'output': final_output}
diff --git a/lib/topics.py b/lib/topics.py
index 06ce863..82f25f2 100644
--- a/lib/topics.py
+++ b/lib/topics.py
@@ -10,6 +10,8 @@
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import Runnable
+from streaming import StreamContext, run_runnable_with_streaming
+
class TopicsProcessor():
runnable: Runnable
@@ -34,10 +36,10 @@ class TopicsProcessor():
def __init__(self, runnable: Runnable):
self.runnable = runnable
- def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
+ def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]:
messages = [
SystemMessage(content=self.system_prompt),
HumanMessage(content=self.user_prompt.format_prompt(text=inputs['input']).to_string())
]
- output = self.runnable.invoke(messages).content
+ output = run_runnable_with_streaming(self.runnable, messages, context)
return {'output': output}
\ No newline at end of file