Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions src/bot/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
should_send_as_photo,
validate_image_path,
)
from .utils.message_buffer import BufferedResult, BufferKey, MessageBuffer

logger = structlog.get_logger()

Expand Down Expand Up @@ -135,6 +136,12 @@ def __init__(self, settings: Settings, deps: Dict[str, Any]):
self.deps = deps
self._active_requests: Dict[int, ActiveRequest] = {}
self._known_commands: frozenset[str] = frozenset()
self._user_locks: Dict[int, asyncio.Lock] = {}
self._message_buffer = MessageBuffer(
chunk_timeout=settings.chunk_buffer_timeout,
chunk_threshold=settings.chunk_buffer_threshold,
on_flush=self._on_buffer_flush,
)

def _inject_deps(self, handler: Callable) -> Callable: # type: ignore[type-arg]
"""Wrap handler to inject dependencies into context.bot_data."""
Expand Down Expand Up @@ -912,10 +919,23 @@ async def _send_images(

return caption_sent

def _get_user_lock(self, user_id: int) -> asyncio.Lock:
    """Fetch the asyncio.Lock dedicated to *user_id*, creating it lazily."""
    # setdefault keeps the first lock ever stored for this user; a fresh
    # Lock is only inserted when the key is missing.
    return self._user_locks.setdefault(user_id, asyncio.Lock())

async def agentic_text(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Direct Claude passthrough. Simple progress. No suggestions."""
"""Entry point for text messages in agentic mode.

Detects Telegram-split chunks (messages near the 4096-char limit)
and buffers them before processing. Short messages bypass the
buffer and are processed immediately.
"""
user_id = update.effective_user.id
message_text = update.message.text

Expand All @@ -925,14 +945,80 @@ async def agentic_text(
message_length=len(message_text),
)

# Rate limit check
# Rate limit check (runs on every chunk — cheap)
rate_limiter = context.bot_data.get("rate_limiter")
if rate_limiter:
allowed, limit_message = await rate_limiter.check_rate_limit(user_id, 0.001)
if not allowed:
await update.message.reply_text(f"⏱️ {limit_message}")
await update.message.reply_text(f"\u23f1\ufe0f {limit_message}")
return

# --- Chunk buffering -----------------------------------------------
chat_id = update.message.chat.id
thread_id = self._extract_message_thread_id(update)
buf_key: BufferKey = (user_id, chat_id, thread_id)

should_buffer = self._message_buffer.has_buffer(
buf_key
) or self._message_buffer.is_likely_chunk(
message_text, self.settings.chunk_buffer_threshold
)

if should_buffer:
result = await self._message_buffer.add_chunk(
buf_key, message_text, update, context
)
if result is None:
# Chunk buffered, timer pending. Return quickly so
# the sequential lock is released for the next chunk.
return
# Buffer flushed (short tail chunk or single non-chunked message).
message_text = result.combined_text
update = result.first_update
context = result.last_context
if result.chunk_count > 1:
logger.info(
"Chunk buffer flushed inline",
user_id=user_id,
chunk_count=result.chunk_count,
combined_length=len(message_text),
)

# --- Process (may also be called from _on_buffer_flush) ------------
lock = self._get_user_lock(user_id)
async with lock:
await self._process_agentic_text(update, context, message_text)

async def _on_buffer_flush(self, key: BufferKey, result: BufferedResult) -> None:
    """Handle a timer-fired flush from the ``MessageBuffer``.

    Runs as an independent ``asyncio.Task`` — outside the sequential
    update lock — but still serialised per user via the user lock.
    """
    user_id = key[0]
    logger.info(
        "Buffer flush via timer",
        user_id=user_id,
        chunk_count=result.chunk_count,
        combined_length=len(result.combined_text),
    )
    async with self._get_user_lock(user_id):
        await self._process_agentic_text(
            result.first_update, result.last_context, result.combined_text
        )

async def _process_agentic_text(
self,
update: Update,
context: ContextTypes.DEFAULT_TYPE,
message_text: str,
) -> None:
"""Run *message_text* through Claude and deliver the response.

Extracted from ``agentic_text`` so it can be called both from the
inline handler path and from the timer-fired buffer flush.
"""
user_id = update.effective_user.id
chat = update.message.chat
await chat.send_action("typing")

Expand Down Expand Up @@ -1684,6 +1770,11 @@ async def _handle_stop_callback(
)
return

# Cancel any pending chunk buffer for this user.
for buf_key in self._message_buffer.pending_keys:
if buf_key[0] == target_user_id:
self._message_buffer.cancel(buf_key)

active = self._active_requests.get(target_user_id)
if not active:
await query.answer("Already completed.", show_alert=False)
Expand Down
245 changes: 245 additions & 0 deletions src/bot/utils/message_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
"""Buffer for Telegram message chunks from multi-part pastes.

When a user pastes text longer than 4096 characters, the Telegram client
silently splits it into multiple messages. This module detects that pattern
(messages at or near the 4096-char limit arriving in quick succession) and
coalesces them into a single string before handing them off for processing.

Design constraints
------------------
* The bot's ``StopAwareUpdateProcessor`` serialises non-priority updates with
a global ``asyncio.Lock``. To allow the *next* chunk to enter the handler
quickly, the handler must **return immediately** when buffering — no blocking
on ``run_command()``.
* Once the debounce timer fires (or a short "tail" chunk is detected), the
combined text is handed to a *flush callback* that runs as an independent
``asyncio.Task``, outside the sequential lock.
"""

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple

import structlog

logger = structlog.get_logger()

# Type alias — keep it simple, avoid NamedTuple overhead.
# Used as the dict key identifying one conversation's pending buffer.
BufferKey = Tuple[int, int, Optional[int]]  # (user_id, chat_id, thread_id)


@dataclass
class BufferedResult:
    """Payload returned when the buffer is ready for processing.

    Produced when chunks are flushed, either inline from ``add_chunk``
    or by the debounce timer.
    """

    combined_text: str  # all buffered chunk texts joined in arrival order
    first_update: Any  # telegram.Update — the update of the first chunk
    last_context: Any  # ContextTypes.DEFAULT_TYPE — context of the most recent chunk
    chunk_count: int  # number of chunks merged into combined_text


@dataclass
class _BufferEntry:
    """Internal mutable state for one pending buffer."""

    texts: List[str] = field(default_factory=list)  # chunk texts, in arrival order
    first_update: Any = None  # telegram.Update of the first buffered chunk
    last_context: Any = None  # context of the most recently buffered chunk
    timer_task: Optional[asyncio.Task[None]] = None  # pending debounce timer, if any
    status_message: Any = None  # telegram.Message ("Receiving message…") deleted on flush


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

# Async callback invoked with the buffer key and combined result when a
# debounce timer flushes a buffer.
FlushCallback = Callable[[BufferKey, BufferedResult], Coroutine[Any, Any, None]]


class MessageBuffer:
    """Per-key debounce buffer for Telegram message chunks.

    Collects near-limit message chunks per ``(user_id, chat_id, thread_id)``
    key and flushes them as one combined text — inline when a short tail
    chunk arrives, or via the debounce timer otherwise.

    Parameters
    ----------
    chunk_timeout:
        Seconds to wait after the last chunk before flushing.
    chunk_threshold:
        Minimum message length (chars) to consider a message a potential
        Telegram-split chunk.
    on_flush:
        Async callback invoked when the buffer flushes via timer.
        Signature: ``async def on_flush(key, result) -> None``.
    """

    def __init__(
        self,
        chunk_timeout: float = 0.5,
        chunk_threshold: int = 3000,
        on_flush: Optional[FlushCallback] = None,
    ) -> None:
        self._buffers: Dict[BufferKey, _BufferEntry] = {}
        self._chunk_timeout = chunk_timeout
        self._chunk_threshold = chunk_threshold
        self._on_flush = on_flush
        # Strong references to fire-and-forget tasks (flush callbacks and
        # status-message deletions).  The event loop holds only weak
        # references to tasks, so without this set a pending task could be
        # garbage-collected before it runs.
        self._background_tasks: set = set()

    # -- query ---------------------------------------------------------------

    @property
    def pending_keys(self) -> List[BufferKey]:
        """Return keys that currently have buffered chunks."""
        return list(self._buffers.keys())

    def has_buffer(self, key: BufferKey) -> bool:
        """Return True if there is a pending buffer for *key*."""
        return key in self._buffers

    @staticmethod
    def is_likely_chunk(text: str, threshold: int) -> bool:
        """Return True if *text* looks like a Telegram-split chunk.

        A message at or above *threshold* chars may be one piece of a
        client-side split of a long paste.
        """
        return len(text) >= threshold

    # -- mutate --------------------------------------------------------------

    async def add_chunk(
        self,
        key: BufferKey,
        text: str,
        update: Any,
        context: Any,
    ) -> Optional[BufferedResult]:
        """Append a message chunk to the buffer.

        Parameters
        ----------
        key:
            ``(user_id, chat_id, thread_id)`` identifying the conversation.
        text:
            The chunk's text.
        update / context:
            The originating Telegram update and handler context. The first
            update and the most recent context are kept for the flush.

        Returns
        -------
        ``None``
            The chunk was buffered; caller should return immediately.
        ``BufferedResult``
            The buffer has been flushed (short tail chunk detected or
            the caller should process the combined text inline).
        """
        entry = self._buffers.get(key)

        if entry is not None:
            # ---- existing buffer: append ----------------------------------
            entry.texts.append(text)
            entry.last_context = context
            self._cancel_timer(entry)

            if self.is_likely_chunk(text, self._chunk_threshold):
                # Another full-size chunk — more may follow.
                self._schedule_timer(key)
                return None

            # Short chunk → likely the tail. Flush immediately.
            return self._pop_result(key)

        # ---- no existing buffer -------------------------------------------
        if not self.is_likely_chunk(text, self._chunk_threshold):
            # Short message, nothing pending — nothing to buffer.
            return BufferedResult(
                combined_text=text,
                first_update=update,
                last_context=context,
                chunk_count=1,
            )

        # First full-size chunk — start buffering.
        entry = _BufferEntry(
            texts=[text],
            first_update=update,
            last_context=context,
        )
        self._buffers[key] = entry

        # Send a lightweight status message so the user sees feedback.
        # Best-effort: a failure here must not break chunk collection.
        try:
            entry.status_message = await update.message.reply_text(
                "Receiving message\u2026"
            )
        except Exception:
            logger.debug("Failed to send buffer status message")

        self._schedule_timer(key)
        return None

    def cancel(self, key: BufferKey) -> Optional[BufferedResult]:
        """Cancel a pending buffer synchronously.

        Returns the accumulated result (so the caller can decide whether
        to process or discard it), or ``None`` if nothing was buffered.
        """
        entry = self._buffers.get(key)
        if entry is None:
            return None

        self._cancel_timer(entry)
        # _pop_result() cancels the timer again (idempotent) and schedules
        # deletion of the status message, so no further cleanup is needed
        # here — scheduling a second deletion would be redundant.
        return self._pop_result(key)

    # -- internal ------------------------------------------------------------

    def _spawn(self, coro: Coroutine[Any, Any, None]) -> None:
        """Run *coro* as a background task, keeping a strong reference.

        The reference is dropped automatically once the task finishes.
        """
        task = asyncio.ensure_future(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _schedule_timer(self, key: BufferKey) -> None:
        """(Re)start the debounce timer for *key*'s entry, if it exists."""
        entry = self._buffers.get(key)
        if entry is None:
            return
        # The entry itself holds the reference, keeping the timer alive.
        entry.timer_task = asyncio.create_task(self._timer_coro(key))

    @staticmethod
    def _cancel_timer(entry: _BufferEntry) -> None:
        """Cancel and forget *entry*'s timer task, if one is still running."""
        if entry.timer_task is not None and not entry.timer_task.done():
            entry.timer_task.cancel()
        entry.timer_task = None

    async def _timer_coro(self, key: BufferKey) -> None:
        """Sleep for the debounce period, then flush."""
        try:
            await asyncio.sleep(self._chunk_timeout)
        except asyncio.CancelledError:
            return

        result = self._pop_result(key)
        if result is None:
            return  # entry was already consumed (e.g. by cancel)

        logger.info(
            "Chunk buffer timer fired",
            user_id=key[0],
            chunk_count=result.chunk_count,
            combined_length=len(result.combined_text),
        )

        if self._on_flush is not None:
            # Independent task: runs outside the sequential update lock.
            self._spawn(self._on_flush(key, result))

    def _pop_result(self, key: BufferKey) -> Optional[BufferedResult]:
        """Remove the entry for *key* and return a ``BufferedResult``."""
        entry = self._buffers.pop(key, None)
        if entry is None:
            return None

        self._cancel_timer(entry)

        # Best-effort cleanup of the status message.
        if entry.status_message is not None:
            self._spawn(self._delete_message(entry.status_message))

        return BufferedResult(
            combined_text="".join(entry.texts),
            first_update=entry.first_update,
            last_context=entry.last_context,
            chunk_count=len(entry.texts),
        )

    @staticmethod
    async def _delete_message(msg: Any) -> None:
        """Delete *msg*, swallowing any Telegram API error (best-effort)."""
        try:
            await msg.delete()
        except Exception:
            pass
Loading