diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 47a90a7..cf8f30d 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -123,12 +123,21 @@ jobs: pipx install poetry poetry install + - name: Cache llm2 models + uses: actions/cache/restore@v5 + id: cache-llm2-models-restore + env: + cache-name: cache-llm2-models + with: + path: llm2-persistent_storage/ + key: ${{ runner.os }}-llm2-models-${{ env.cache-name }}-${{ hashFiles('llm2/lib/main.py') }} + - name: Install and init backend working-directory: ${{ env.APP_NAME }}/lib env: APP_VERSION: ${{ fromJson(steps.appinfo.outputs.result).version }} run: | - poetry run python3 main.py > ../backend_logs 2>&1 & + APP_PERSISTENT_STORAGE="$(pwd)/../../llm2-persistent_storage/" poetry run python3 main.py > ../backend_logs 2>&1 & - name: Register backend run: | @@ -156,6 +165,43 @@ jobs: curl -u "$CREDS" -H "oCS-APIRequest: true" http://localhost:8080/ocs/v2.php/taskprocessing/task/$TASK_ID?format=json [ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ] + - name: Cache llm2 models + uses: actions/cache/save@v5 + env: + cache-name: cache-llm2-models + with: + path: llm2-persistent_storage/ + key: ${{ steps.cache-llm2-models-restore.outputs.cache-primary-key }} + + - name: Run streaming task + if: matrix.server-versions == 'master' + env: + CREDS: "admin:password" + run: | + set -x + TASK=$(curl -X POST -u "$CREDS" -H "oCS-APIRequest: true" -H "Content-type: application/json" http://localhost:8080/ocs/v2.php/taskprocessing/schedule?format=json --data-raw '{"input": {"input": "Count from 1 to 20 in words"},"type":"core:text2text", "appId": "test", "customId": "", "preferStreaming": true}') + echo $TASK + TASK_ID=$(echo $TASK | jq '.ocs.data.task.id') + NEXT_WAIT_TIME=0 + TASK_STATUS='"STATUS_SCHEDULED"' + STREAMING_UPDATES=0 + until [ $NEXT_WAIT_TIME -eq 35 ] || [ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ] || [ "$TASK_STATUS" == '"STATUS_FAILED"' 
]; do + TASK=$(curl -u "$CREDS" -H "oCS-APIRequest: true" http://localhost:8080/ocs/v2.php/taskprocessing/task/$TASK_ID?format=json) + echo $TASK + TASK_STATUS=$(echo $TASK | jq '.ocs.data.task.status') + echo $TASK_STATUS + TASK_OUTPUT=$(echo $TASK | jq -r '.ocs.data.task.output.output // ""') + if [ -n "$TASK_OUTPUT" ] && [ "$TASK_STATUS" != '"STATUS_SUCCESSFUL"' ] && [ "$TASK_STATUS" != '"STATUS_FAILED"' ]; then + STREAMING_UPDATES=$((STREAMING_UPDATES+1)) + echo "Streaming update detected (count: $STREAMING_UPDATES)" + fi + sleep $(( NEXT_WAIT_TIME++ )) + done + echo "Final status: $TASK_STATUS" + echo "Total streaming updates detected: $STREAMING_UPDATES" + [ "$TASK_STATUS" == '"STATUS_SUCCESSFUL"' ] + [ $STREAMING_UPDATES -gt 0 ] + - name: Show logs if: always() run: | diff --git a/default_config/config.json b/default_config/config.json index 9265221..9850a28 100644 --- a/default_config/config.json +++ b/default_config/config.json @@ -61,7 +61,7 @@ "Qwen3.5-9B-Q4_K_M": { "prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n{user_prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n", "loader_config": { - "n_ctx": 16384, + "n_ctx": 24000, "max_tokens": 8192, "stop": ["<|eot_id|>"], "temperature": 0.7 diff --git a/lib/change_tone.py b/lib/change_tone.py index 317fb32..cb2e771 100644 --- a/lib/change_tone.py +++ b/lib/change_tone.py @@ -10,6 +10,8 @@ from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class ChangeToneProcessor: runnable: Runnable @@ -33,10 +35,10 @@ class ChangeToneProcessor: def __init__(self, runnable: Runnable): self.runnable = runnable - def __call__(self, input_data: dict) -> dict[str, Any]: + def __call__(self, input_data: dict, context: StreamContext | None = None) -> dict[str, Any]: """Process a single input""" 
messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format_prompt(text=input_data['input'], tone=input_data['tone']).to_string()) ] - return {'output':self.runnable.invoke(messages).content } \ No newline at end of file + return {'output': run_runnable_with_streaming(self.runnable, messages, context)} diff --git a/lib/chat.py b/lib/chat.py index f848b45..d2c2815 100644 --- a/lib/chat.py +++ b/lib/chat.py @@ -3,13 +3,12 @@ """A chat chain """ import json -from typing import Any, Optional +from typing import Any -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.chains.base import Chain -from langchain_community.chat_models import ChatLlamaCpp from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class ChatProcessor: """ @@ -24,10 +23,19 @@ def __init__(self, runner: Runnable): def __call__( self, inputs: dict[str, Any], + context: StreamContext | None = None, ) -> dict[str, str]: system_prompt = inputs['system_prompt'] if inputs.get('memories'): system_prompt += "\n\nYou can remember things from other conversations with the user. 
If they are relevant, take into account the following memories: \n" + "\n\n".join(inputs['memories']) + "\n\n" - return {'output': self.runnable.invoke( - [('human', system_prompt)] + [(message['role'], message['content']) for message in [json.loads(message) for message in inputs['history']]] + [('human', inputs['input'])] - ).content} \ No newline at end of file + messages = [('human', system_prompt)] + [ + (message['role'], message['content']) + for message in [json.loads(message) for message in inputs['history']] + ] + [('human', inputs['input'])] + return { + 'output': run_runnable_with_streaming( + self.runnable, + messages, + context, + ) + } diff --git a/lib/chatwithtools.py b/lib/chatwithtools.py index 7962162..461b501 100644 --- a/lib/chatwithtools.py +++ b/lib/chatwithtools.py @@ -3,21 +3,34 @@ """A chat chain """ import json +import hashlib import pprint import re -from random import randint from typing import Any from langchain_community.chat_models import ChatLlamaCpp -from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage +from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.messages.ai import AIMessage +from streaming import StreamContext, run_runnable_with_streaming + def generate_tool_call(tool_call: dict): content = '' content += json.dumps({"name": tool_call['name'], "arguments": tool_call['args']}) content += '' return content + +def generate_tool_call_id(tool_call: dict) -> str: + stable_payload = json.dumps( + { + "name": tool_call.get("name"), + "args": tool_call.get("args", tool_call.get("arguments", {})), + }, + sort_keys=True, + ) + return hashlib.sha1(stable_payload.encode("utf-8")).hexdigest()[:16] + def try_parse_tool_calls(content: str): """Try parse the tool calls.""" tool_calls = [] @@ -40,7 +53,9 @@ def try_parse_tool_calls(content: str): func['args'] = func['arguments'] del func['arguments'] if not 'id' in func: - func['id'] = str(randint(1, 10000000000)) + func['id'] = 
generate_tool_call_id(func) + if 'type' not in func: + func['type'] = 'tool_call' found = True except json.JSONDecodeError as e: print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}") @@ -66,7 +81,9 @@ def try_parse_tool_calls(content: str): func['args'] = func['arguments'] del func['arguments'] if not 'id' in func: - func['id'] = str(randint(1, 10000000000)) + func['id'] = generate_tool_call_id(func) + if 'type' not in func: + func['type'] = 'tool_call' except json.JSONDecodeError as e: print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}") pass @@ -79,6 +96,31 @@ def try_parse_tool_calls(content: str): return {"role": "assistant", "content": c, "tool_calls": tool_calls} return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)} + +def strip_tool_calls_for_streaming(content: str) -> str: + sanitized = re.sub(r".*?", "", content, flags=re.DOTALL) + sanitized = re.sub(r"```tool_call\n.*?\n```", "", sanitized, flags=re.DOTALL) + + partial_markers = [index for index in (sanitized.find("$", "", sanitized) + + +def build_streaming_payload(content: str) -> dict[str, Any] | None: + payload: dict[str, Any] = {} + cleaned_output = strip_tool_calls_for_streaming(content) + parsed_response = try_parse_tool_calls(content) + tool_calls = parsed_response.get('tool_calls') + + if cleaned_output or tool_calls: + payload['output'] = cleaned_output + if tool_calls: + payload['tool_calls'] = json.dumps(tool_calls) + + return payload or None + class ChatWithToolsProcessor: """ A chat with tools processor that supports batch processing @@ -89,7 +131,7 @@ class ChatWithToolsProcessor: def __init__(self, runner: ChatLlamaCpp): self.model = runner - def _process_single_input(self, input_data: dict[str, Any]) -> dict[str, Any]: + def _process_single_input(self, input_data: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: system_prompt = """ {downstream_system_prompt} @@ -150,15 +192,20 @@ def 
_process_single_input(self, input_data: dict[str, Any]) -> dict[str, Any]: messages.append(HumanMessage(content='')) pprint.pprint(messages) - response = self.model.invoke(messages) + response_content = run_runnable_with_streaming( + self.model, + messages, + context, + stream_payload_transform=build_streaming_payload, + suppress_empty_stream_updates=True, + ) - #if not response.tool_calls or len(response.tool_calls) == 0: - response = AIMessage(**try_parse_tool_calls(response.content)) + response = AIMessage(**try_parse_tool_calls(response_content)) return { 'output': response.content, 'tool_calls': json.dumps(response.tool_calls) } - def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: - return self._process_single_input(inputs) \ No newline at end of file + def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: + return self._process_single_input(inputs, context) diff --git a/lib/contextwrite.py b/lib/contextwrite.py index 8e8ef08..8fe62fc 100644 --- a/lib/contextwrite.py +++ b/lib/contextwrite.py @@ -10,6 +10,8 @@ from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class ContextWriteProcessor: runnable: Runnable @@ -36,7 +38,7 @@ class ContextWriteProcessor: def __init__(self, runnable: Runnable): self.runnable = runnable - def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format( @@ -44,5 +46,5 @@ def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: source_input=inputs['source_input'] )) ] - output = self.runnable.invoke(messages) - return {'output': output.content} \ No newline at end of file + output = run_runnable_with_streaming(self.runnable, messages, context) + 
return {'output': output} diff --git a/lib/free_prompt.py b/lib/free_prompt.py index 4f2cdd2..cb5ca25 100644 --- a/lib/free_prompt.py +++ b/lib/free_prompt.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later -from typing import Any, List +from typing import Any from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class FreePromptProcessor: """ @@ -20,9 +22,10 @@ def __init__(self, runnable: Runnable): def __call__( self, inputs: dict[str, Any], + context: StreamContext | None = None, ) -> dict[str, Any]: - output = self.runnable.invoke([ + output = run_runnable_with_streaming(self.runnable, [ SystemMessage(self.system_prompt), HumanMessage(inputs['input']) - ]).content + ], context) return {'output': output} \ No newline at end of file diff --git a/lib/headline.py b/lib/headline.py index dc8a6d7..ac88832 100644 --- a/lib/headline.py +++ b/lib/headline.py @@ -9,6 +9,8 @@ from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class HeadlineProcessor: """ @@ -33,12 +35,12 @@ class HeadlineProcessor: def __init__(self, runnable: Runnable): self.runnable = runnable - def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format( text=inputs['input'] )) ] - output = self.runnable.invoke(messages) - return {'output': output.content} + output = run_runnable_with_streaming(self.runnable, messages, context) + return {'output': output} diff --git a/lib/main.py b/lib/main.py index 65c9279..652b200 100644 --- a/lib/main.py +++ b/lib/main.py @@ -13,6 
+13,7 @@ from time import perf_counter, sleep, strftime from niquests import RequestException +from streaming import StreamContext from task_processors import generate_task_processors from fastapi import FastAPI from nc_py_api import AsyncNextcloudApp, NextcloudApp, NextcloudException @@ -24,6 +25,35 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + +class NextcloudTaskStreamResult: + def __init__(self, nc: NextcloudApp, task_id: int, enabled: bool): + self.nc = nc + self.task_id = task_id + self.enabled = enabled + + def send(self, output: dict) -> None: + if not self.enabled: + return + + try: + self.nc.ocs( + "POST", + f"/ocs/v2.php/taskprocessing/tasks_provider/{self.task_id}/stream-result", + json={"output": output}, + ) + except (NextcloudException, RequestException, JSONDecodeError) as e: + log(self.nc, LogLvl.WARNING, f"Streaming intermediate task output failed for task {self.task_id}: {e}") + self.enabled = False + + def set_progress(self, progress: float) -> bool: + try: + self.nc.providers.task_processing.set_progress(self.task_id, progress) + except (NextcloudException, RequestException, JSONDecodeError) as e: + log(self.nc, LogLvl.WARNING, f"Updating progress failed for task {self.task_id}: {e}") + return False + return True + def log(nc, level, content): logger.log((level+1)*10, content) if level < LogLvl.WARNING: @@ -136,7 +166,12 @@ def background_thread_task(): log(nc, LogLvl.INFO, "Generating reply") time_start = perf_counter() log(nc, LogLvl.INFO, task.get("input")) - result = task_processor(task.get("input")) + stream_result = NextcloudTaskStreamResult(nc, task["id"], bool(task.get("preferStreaming", None))) + stream_context = StreamContext( + stream_result=stream_result.send if stream_result.enabled else None, + progress_callback=stream_result.set_progress, + ) + result = task_processor(task.get("input"), context=stream_context) log(nc, LogLvl.INFO, f"reply generated: {round(float(perf_counter() - time_start), 2)}s") 
log(nc, LogLvl.INFO, result) nc.providers.task_processing.report_result( diff --git a/lib/proofread.py b/lib/proofread.py index 5d40b7b..1d4a715 100644 --- a/lib/proofread.py +++ b/lib/proofread.py @@ -9,6 +9,8 @@ from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class ProofreadProcessor: """ @@ -33,12 +35,12 @@ class ProofreadProcessor: def __init__(self, runnable: Runnable): self.runnable = runnable - def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format( text=inputs['input'] )) ] - output = self.runnable.invoke(messages) - return {'output': output.content} \ No newline at end of file + output = run_runnable_with_streaming(self.runnable, messages, context) + return {'output': output} diff --git a/lib/reformulate.py b/lib/reformulate.py index 9b2d70d..d3392c0 100644 --- a/lib/reformulate.py +++ b/lib/reformulate.py @@ -7,6 +7,8 @@ from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class ReformulateProcessor: """ @@ -33,12 +35,12 @@ class ReformulateProcessor: def __init__(self, runnable: Runnable): self.runnable = runnable - def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format( text=inputs['input'] )) ] - output = self.runnable.invoke(messages) - return {'output': output.content} \ No newline at end of file + output = run_runnable_with_streaming(self.runnable, messages, context) + return {'output': 
# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
# Helpers for relaying partial model output while a runnable is generating.
from __future__ import annotations

from dataclasses import dataclass, field
from time import monotonic
from typing import Any, Callable


def _part_text(part: Any) -> str:
    """Return the text carried by one element of a list-valued ``content``.

    Strings are returned as-is. Dicts contribute their ``"text"`` value, or
    their ``"content"`` value when ``"type" == "text"``. Anything else
    contributes nothing.
    """
    if isinstance(part, str):
        return part
    if isinstance(part, dict):
        if isinstance(part.get("text"), str):
            return part["text"]
        if part.get("type") == "text" and isinstance(part.get("content"), str):
            return part["content"]
    return ""


def extract_text_content(value: Any) -> str:
    """Flatten a message, chunk or raw value into plain text.

    Args:
        value: Either the content itself or an object exposing a ``content``
            attribute. The content may be ``None``, a string, a list of
            strings/dict parts, or any other object.

    Returns:
        ``""`` when the value or its ``content`` is ``None``; the string
        itself for string content; the concatenated text parts for list
        content; ``str(content)`` for anything else.
    """
    content = getattr(value, "content", value)

    # Check the unwrapped content, not only the wrapper: an object whose
    # ``content`` is None must not leak the literal text "None".
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(_part_text(part) for part in content)
    return str(content)


@dataclass
class StreamContext:
    """Accumulates intermediate task output and forwards it, rate-limited.

    Attributes are documented inline. ``current_output`` is merged into by
    ``update_output``/``update_text``; ``emit`` hands a shallow copy of it to
    ``stream_result``.
    """

    # Receives a copy of ``current_output``; ``None`` disables streaming.
    stream_result: Callable[[dict[str, Any]], None] | None = None
    # Receives progress values passed to ``set_progress``.
    progress_callback: Callable[[float], Any] | None = None
    # Minimum time between two non-forced emissions.
    stream_interval_seconds: float = 0.75
    current_output: dict[str, Any] = field(default_factory=dict)
    # ``None`` means "nothing emitted yet". time.monotonic() has an undefined
    # reference point, so no float value is a safe sentinel.
    _last_emit_at: float | None = field(default=None, init=False)
    _last_emitted_output: dict[str, Any] | None = field(default=None, init=False)

    @property
    def enabled(self) -> bool:
        """Whether a ``stream_result`` callback is configured."""
        return self.stream_result is not None

    def is_throttled(self, now: float | None = None) -> bool:
        """Return True if a non-forced ``emit`` would be dropped right now.

        Args:
            now: Timestamp from ``time.monotonic()``; read from the clock
                when omitted.
        """
        if self._last_emit_at is None:
            return False
        if now is None:
            now = monotonic()
        return now - self._last_emit_at < self.stream_interval_seconds

    def update_output(self, output: dict[str, Any] | None = None, *, force: bool = False, **extra: Any) -> None:
        """Merge ``output`` and ``extra`` into ``current_output``, then emit."""
        if output:
            self.current_output.update(output)
        if extra:
            self.current_output.update(extra)
        self.emit(force=force)

    def update_text(
        self,
        text: str,
        *,
        key: str = "output",
        force: bool = False,
        **extra_output: Any,
    ) -> None:
        """Store ``text`` under ``key``, merge ``extra_output``, then emit."""
        self.current_output[key] = text
        if extra_output:
            self.current_output.update(extra_output)
        self.emit(force=force)

    def emit(self, *, force: bool = False) -> None:
        """Send ``current_output`` to ``stream_result`` if there is news.

        Nothing is sent when streaming is disabled, when ``current_output`` is
        empty, when the update is throttled (unless ``force``), or when the
        payload equals the previously emitted one.
        """
        if self.stream_result is None or not self.current_output:
            return

        now = monotonic()
        if not force and self.is_throttled(now):
            return

        snapshot = dict(self.current_output)
        if snapshot == self._last_emitted_output:
            return

        self.stream_result(snapshot)
        self._last_emit_at = now
        self._last_emitted_output = snapshot

    def set_progress(self, progress: float) -> Any:
        """Forward ``progress`` to ``progress_callback``.

        Returns:
            The callback's return value, or ``True`` when no callback is set.
        """
        if self.progress_callback is None:
            return True
        return self.progress_callback(progress)


def _build_stream_update(
    text: str,
    *,
    final: bool,
    output_key: str,
    text_transform: Callable[[str], str] | None,
    payload_transform: Callable[[str], dict[str, Any] | None] | None,
    suppress_empty: bool,
    extra_output: dict[str, Any],
) -> dict[str, Any] | None:
    """Build the dict to merge into the stream output, or None to skip."""
    if payload_transform:
        payload = payload_transform(text)
        if payload is None:
            return None
        # Build a new dict instead of updating in place: the transform's
        # return value belongs to the caller and must not be mutated.
        update = {**payload, **extra_output}
        if suppress_empty and not update:
            return None
        return update

    shown = text_transform(text) if text_transform else text
    # The final text update is always published, even when empty.
    if suppress_empty and shown == "" and not final:
        return None
    return {output_key: shown, **extra_output}


def run_runnable_with_streaming(
    runnable: Any,
    messages: list[Any],
    context: StreamContext | None = None,
    *,
    output_key: str = "output",
    stream_text_transform: Callable[[str], str] | None = None,
    stream_payload_transform: Callable[[str], dict[str, Any] | None] | None = None,
    suppress_empty_stream_updates: bool = False,
    **extra_output: Any,
) -> str:
    """Run ``runnable`` on ``messages`` and return the generated text.

    Without an enabled ``context`` this is ``runnable.invoke(messages)``
    reduced to text. With one, ``runnable.stream(messages)`` is consumed and
    the text accumulated so far is published through ``context`` as chunks
    arrive, followed by one forced update with the complete text.

    Args:
        runnable: Object providing ``invoke(messages)`` and ``stream(messages)``.
        messages: Passed through unchanged to the runnable.
        context: Streaming sink; ``None`` or a disabled context means no
            streaming.
        output_key: Key used for the text when no payload transform is given.
        stream_text_transform: Maps the accumulated text to the text that is
            published. The return value of this function is unaffected.
        stream_payload_transform: Maps the accumulated text to a whole update
            dict, or ``None`` to publish nothing. Takes precedence over
            ``stream_text_transform``.
        suppress_empty_stream_updates: Skip empty payloads, and skip empty
            intermediate (not final) text updates.
        **extra_output: Additional key/value pairs merged into every update.

    Returns:
        The untransformed generated text.
    """
    if context is None or not context.enabled:
        return extract_text_content(runnable.invoke(messages))

    build_options: dict[str, Any] = {
        "output_key": output_key,
        "text_transform": stream_text_transform,
        "payload_transform": stream_payload_transform,
        "suppress_empty": suppress_empty_stream_updates,
        "extra_output": extra_output,
    }

    pieces: list[str] = []
    for chunk in runnable.stream(messages):
        piece = extract_text_content(chunk)
        if not piece:
            continue
        pieces.append(piece)
        # Joining and transforming the whole text costs time proportional to
        # its length. Skip it when emit() would drop the update anyway; the
        # forced update after the loop publishes the final state.
        if context.is_throttled():
            continue
        update = _build_stream_update("".join(pieces), final=False, **build_options)
        if update is not None:
            context.update_output(update)

    output = "".join(pieces)
    update = _build_stream_update(output, final=True, **build_options)
    if update is not None:
        context.update_output(update, force=True)
    return output
total_splits = len(splits) if len(splits) == 1: messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format(input=splits[0])) ] - output = self.runnable.invoke(messages) - return {'output': output.content} + output = run_runnable_with_streaming( + self.runnable, + messages, + context, + ) + return {'output': output} # Process each split summaries = [] - for split in splits: + for index, split in enumerate(splits, start=1): messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.user_prompt.format(input=split)) ] output = self.runnable.invoke(messages) summaries.append(output.content) + if context: + context.set_progress(index / (total_splits + 1) * 50) # Merge summaries messages = [ SystemMessage(content=self.system_prompt), HumanMessage(content=self.merge_prompt.format(input="\n\n".join(summaries))) ] - final_output = self.runnable.invoke(messages) - return {'output': final_output.content} \ No newline at end of file + if context: + context.set_progress(total_splits / (total_splits + 1) * 50 + 50) + + final_output = run_runnable_with_streaming( + self.runnable, + messages, + context, + ) + + if context: + context.set_progress(100) + + return {'output': final_output} diff --git a/lib/topics.py b/lib/topics.py index 06ce863..82f25f2 100644 --- a/lib/topics.py +++ b/lib/topics.py @@ -10,6 +10,8 @@ from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.runnables import Runnable +from streaming import StreamContext, run_runnable_with_streaming + class TopicsProcessor(): runnable: Runnable @@ -34,10 +36,10 @@ class TopicsProcessor(): def __init__(self, runnable: Runnable): self.runnable = runnable - def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + def __call__(self, inputs: dict[str, Any], context: StreamContext | None = None) -> dict[str, Any]: messages = [ SystemMessage(content=self.system_prompt), 
HumanMessage(content=self.user_prompt.format_prompt(text=inputs['input']).to_string()) ] - output = self.runnable.invoke(messages).content + output = run_runnable_with_streaming(self.runnable, messages, context) return {'output': output} \ No newline at end of file