diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 6d9719f0..53770631 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -39,6 +39,7 @@ should_send_as_photo, validate_image_path, ) +from .utils.message_buffer import BufferedResult, BufferKey, MessageBuffer logger = structlog.get_logger() @@ -135,6 +136,12 @@ def __init__(self, settings: Settings, deps: Dict[str, Any]): self.deps = deps self._active_requests: Dict[int, ActiveRequest] = {} self._known_commands: frozenset[str] = frozenset() + self._user_locks: Dict[int, asyncio.Lock] = {} + self._message_buffer = MessageBuffer( + chunk_timeout=settings.chunk_buffer_timeout, + chunk_threshold=settings.chunk_buffer_threshold, + on_flush=self._on_buffer_flush, + ) def _inject_deps(self, handler: Callable) -> Callable: # type: ignore[type-arg] """Wrap handler to inject dependencies into context.bot_data.""" @@ -912,10 +919,23 @@ async def _send_images( return caption_sent + def _get_user_lock(self, user_id: int) -> asyncio.Lock: + """Return a per-user lock, creating one if needed.""" + lock = self._user_locks.get(user_id) + if lock is None: + lock = asyncio.Lock() + self._user_locks[user_id] = lock + return lock + async def agentic_text( self, update: Update, context: ContextTypes.DEFAULT_TYPE ) -> None: - """Direct Claude passthrough. Simple progress. No suggestions.""" + """Entry point for text messages in agentic mode. + + Detects Telegram-split chunks (messages near the 4096-char limit) + and buffers them before processing. Short messages bypass the + buffer and are processed immediately. + """ user_id = update.effective_user.id message_text = update.message.text @@ -925,14 +945,80 @@ async def agentic_text( message_length=len(message_text), ) - # Rate limit check + # Rate limit check (runs on every chunk — cheap) rate_limiter = context.bot_data.get("rate_limiter") if rate_limiter: allowed, limit_message = await rate_limiter.check_rate_limit(user_id, 0.001) if not allowed: - await update.message.reply_text(f"⏱️ {limit_message}") + await update.message.reply_text(f"\u23f1\ufe0f {limit_message}") return + # --- Chunk buffering ----------------------------------------------- + chat_id = update.message.chat.id + thread_id = self._extract_message_thread_id(update) + buf_key: BufferKey = (user_id, chat_id, thread_id) + + should_buffer = self._message_buffer.has_buffer( + buf_key + ) or self._message_buffer.is_likely_chunk( + message_text, self.settings.chunk_buffer_threshold + ) + + if should_buffer: + result = await self._message_buffer.add_chunk( + buf_key, message_text, update, context + ) + if result is None: + # Chunk buffered, timer pending. Return quickly so + # the sequential lock is released for the next chunk. + return + # Buffer flushed (short tail chunk or single non-chunked message). + message_text = result.combined_text + update = result.first_update + context = result.last_context + if result.chunk_count > 1: + logger.info( + "Chunk buffer flushed inline", + user_id=user_id, + chunk_count=result.chunk_count, + combined_length=len(message_text), + ) + + # --- Process (may also be called from _on_buffer_flush) ------------ + lock = self._get_user_lock(user_id) + async with lock: + await self._process_agentic_text(update, context, message_text) + + async def _on_buffer_flush(self, key: BufferKey, result: BufferedResult) -> None: + """Called by MessageBuffer timer when all chunks have been collected. + + Runs as an independent ``asyncio.Task`` — outside the sequential + lock but guarded by a per-user lock. + """ + logger.info( + "Buffer flush via timer", + user_id=key[0], + chunk_count=result.chunk_count, + combined_length=len(result.combined_text), + ) + lock = self._get_user_lock(key[0]) + async with lock: + await self._process_agentic_text( + result.first_update, result.last_context, result.combined_text + ) + + async def _process_agentic_text( + self, + update: Update, + context: ContextTypes.DEFAULT_TYPE, + message_text: str, + ) -> None: + """Run *message_text* through Claude and deliver the response. + + Extracted from ``agentic_text`` so it can be called both from the + inline handler path and from the timer-fired buffer flush. + """ + user_id = update.effective_user.id chat = update.message.chat await chat.send_action("typing") @@ -1684,6 +1770,11 @@ async def _handle_stop_callback( ) return + # Cancel any pending chunk buffer for this user. + for buf_key in self._message_buffer.pending_keys: + if buf_key[0] == target_user_id: + self._message_buffer.cancel(buf_key) + active = self._active_requests.get(target_user_id) if not active: await query.answer("Already completed.", show_alert=False) diff --git a/src/bot/utils/message_buffer.py b/src/bot/utils/message_buffer.py new file mode 100644 index 00000000..79240d51 --- /dev/null +++ b/src/bot/utils/message_buffer.py @@ -0,0 +1,245 @@ +"""Buffer for Telegram message chunks from multi-part pastes. + +When a user pastes text longer than 4096 characters, the Telegram client +silently splits it into multiple messages. This module detects that pattern +(messages at or near the 4096-char limit arriving in quick succession) and +coalesces them into a single string before handing them off for processing. + +Design constraints +------------------ +* The bot's ``StopAwareUpdateProcessor`` serialises non-priority updates with + a global ``asyncio.Lock``. To allow the *next* chunk to enter the handler + quickly, the handler must **return immediately** when buffering — no blocking + on ``run_command()``. +* Once the debounce timer fires (or a short "tail" chunk is detected), the + combined text is handed to a *flush callback* that runs as an independent + ``asyncio.Task``, outside the sequential lock. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple + +import structlog + +logger = structlog.get_logger() + +# Type alias — keep it simple, avoid NamedTuple overhead. +BufferKey = Tuple[int, int, Optional[int]] # (user_id, chat_id, thread_id) + + +@dataclass +class BufferedResult: + """Payload returned when the buffer is ready for processing.""" + + combined_text: str + first_update: Any # telegram.Update + last_context: Any # ContextTypes.DEFAULT_TYPE + chunk_count: int + + +@dataclass +class _BufferEntry: + """Internal mutable state for one pending buffer.""" + + texts: List[str] = field(default_factory=list) + first_update: Any = None + last_context: Any = None + timer_task: Optional[asyncio.Task[None]] = None + status_message: Any = None # telegram.Message + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +FlushCallback = Callable[[BufferKey, BufferedResult], Coroutine[Any, Any, None]] + + +class MessageBuffer: + """Per-key debounce buffer for Telegram message chunks. + + Parameters + ---------- + chunk_timeout: + Seconds to wait after the last chunk before flushing. + chunk_threshold: + Minimum message length (chars) to consider a message a potential + Telegram-split chunk. + on_flush: + Async callback invoked when the buffer flushes via timer. + Signature: ``async def on_flush(key, result) -> None``. + """ + + def __init__( + self, + chunk_timeout: float = 0.5, + chunk_threshold: int = 3000, + on_flush: Optional[FlushCallback] = None, + ) -> None: + self._buffers: Dict[BufferKey, _BufferEntry] = {} + self._chunk_timeout = chunk_timeout + self._chunk_threshold = chunk_threshold + self._on_flush = on_flush + + # -- query --------------------------------------------------------------- + + @property + def pending_keys(self) -> List[BufferKey]: + """Return keys that currently have buffered chunks.""" + return list(self._buffers.keys()) + + def has_buffer(self, key: BufferKey) -> bool: + """Return True if there is a pending buffer for *key*.""" + return key in self._buffers + + @staticmethod + def is_likely_chunk(text: str, threshold: int) -> bool: + """Return True if *text* looks like a Telegram-split chunk.""" + return len(text) >= threshold + + # -- mutate -------------------------------------------------------------- + + async def add_chunk( + self, + key: BufferKey, + text: str, + update: Any, + context: Any, + ) -> Optional[BufferedResult]: + """Append a message chunk to the buffer. + + Returns + ------- + ``None`` + The chunk was buffered; caller should return immediately. + ``BufferedResult`` + The buffer has been flushed (short tail chunk detected or + the caller should process the combined text inline). + """ + entry = self._buffers.get(key) + + if entry is not None: + # ---- existing buffer: append ---------------------------------- + entry.texts.append(text) + entry.last_context = context + self._cancel_timer(entry) + + if self.is_likely_chunk(text, self._chunk_threshold): + # Another full-size chunk — more may follow. + self._schedule_timer(key) + return None + + # Short chunk → likely the tail. Flush immediately. + return self._pop_result(key) + + # ---- no existing buffer ------------------------------------------- + if not self.is_likely_chunk(text, self._chunk_threshold): + # Short message, nothing pending — nothing to buffer. + return BufferedResult( + combined_text=text, + first_update=update, + last_context=context, + chunk_count=1, + ) + + # First full-size chunk — start buffering. + entry = _BufferEntry( + texts=[text], + first_update=update, + last_context=context, + ) + self._buffers[key] = entry + + # Send a lightweight status message so the user sees feedback. + try: + entry.status_message = await update.message.reply_text( + "Receiving message\u2026" + ) + except Exception: + logger.debug("Failed to send buffer status message") + + self._schedule_timer(key) + return None + + def cancel(self, key: BufferKey) -> Optional[BufferedResult]: + """Cancel a pending buffer synchronously. + + Returns the accumulated result (so the caller can decide whether + to process or discard it), or ``None`` if nothing was buffered. + """ + entry = self._buffers.get(key) + if entry is None: + return None + + self._cancel_timer(entry) + result = self._pop_result(key) + + # Best-effort cleanup of the status message. + if entry.status_message is not None: + asyncio.ensure_future(self._delete_message(entry.status_message)) + + return result + + # -- internal ------------------------------------------------------------ + + def _schedule_timer(self, key: BufferKey) -> None: + entry = self._buffers.get(key) + if entry is None: + return + entry.timer_task = asyncio.create_task(self._timer_coro(key)) + + @staticmethod + def _cancel_timer(entry: _BufferEntry) -> None: + if entry.timer_task is not None and not entry.timer_task.done(): + entry.timer_task.cancel() + entry.timer_task = None + + async def _timer_coro(self, key: BufferKey) -> None: + """Sleep for the debounce period, then flush.""" + try: + await asyncio.sleep(self._chunk_timeout) + except asyncio.CancelledError: + return + + result = self._pop_result(key) + if result is None: + return # entry was already consumed (e.g. by cancel) + + logger.info( + "Chunk buffer timer fired", + user_id=key[0], + chunk_count=result.chunk_count, + combined_length=len(result.combined_text), + ) + + if self._on_flush is not None: + asyncio.create_task(self._on_flush(key, result)) + + def _pop_result(self, key: BufferKey) -> Optional[BufferedResult]: + """Remove the entry for *key* and return a ``BufferedResult``.""" + entry = self._buffers.pop(key, None) + if entry is None: + return None + + self._cancel_timer(entry) + + # Best-effort cleanup of the status message. + if entry.status_message is not None: + asyncio.ensure_future(self._delete_message(entry.status_message)) + + return BufferedResult( + combined_text="".join(entry.texts), + first_update=entry.first_update, + last_context=entry.last_context, + chunk_count=len(entry.texts), + ) + + @staticmethod + async def _delete_message(msg: Any) -> None: + try: + await msg.delete() + except Exception: + pass diff --git a/src/config/settings.py b/src/config/settings.py index c4f7cb18..c0cfa5ff 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -281,6 +281,32 @@ class Settings(BaseSettings): le=5.0, ) + # Message chunk buffering (multi-part paste detection) + chunk_buffer_timeout: float = Field( + 0.5, + description=( + "Seconds to wait for additional message chunks before processing. " + "Applies when a pasted message exceeds Telegram's 4096-char limit." + ), + ge=0.1, + le=3.0, + ) + chunk_buffer_threshold: int = Field( + 3000, + description=( + "Minimum message length (chars) to trigger chunk buffering. " + "Messages shorter than this are never buffered. " + "Telegram clients often split long pastes at paragraph/line " + "boundaries rather than at the 4096-char hard limit, so chunks " + "can arrive well below 4000 chars — the default of 3000 catches " + "those splits. Lower this further if you see multi-part pastes " + "slipping through; raise it if you often send single long " + "messages that don't need buffering." + ), + ge=2000, + le=4096, + ) + # Monitoring log_level: str = Field("INFO", description="Logging level") enable_telemetry: bool = Field(False, description="Enable anonymous telemetry") diff --git a/src/utils/constants.py b/src/utils/constants.py index 7b66f9a6..2c6cf66f 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -91,5 +91,13 @@ DEFAULT_RETRY_BACKOFF_FACTOR = 3.0 DEFAULT_RETRY_MAX_DELAY = 30.0 +# Message chunk buffering (multi-part paste detection) +DEFAULT_CHUNK_BUFFER_TIMEOUT = 0.5 +# Telegram clients often split long pastes at paragraph/line boundaries +# rather than at the 4096-char hard limit, so chunks can land well below +# 4096. 3000 catches the common split points while still being well above +# normal chat messages. +DEFAULT_CHUNK_BUFFER_THRESHOLD = 3000 + # Logging LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" diff --git a/tests/unit/test_bot/test_message_buffer.py b/tests/unit/test_bot/test_message_buffer.py new file mode 100644 index 00000000..848a59b7 --- /dev/null +++ b/tests/unit/test_bot/test_message_buffer.py @@ -0,0 +1,290 @@ +"""Tests for the MessageBuffer (multi-part paste detection).""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from src.bot.utils.message_buffer import ( + BufferedResult, + BufferKey, + MessageBuffer, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_update(user_id: int = 1, chat_id: int = 100, thread_id: int = 0) -> MagicMock: + update = MagicMock() + update.effective_user.id = user_id + update.message.chat.id = chat_id + update.message.message_thread_id = thread_id + update.message.message_id = 42 + update.message.reply_text = AsyncMock(return_value=MagicMock(delete=AsyncMock())) + return update + + +def _make_context() -> MagicMock: + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {} + return ctx + + +KEY: BufferKey = (1, 100, None) +LONG_TEXT = "x" * 4096 +SHORT_TEXT = "hello" + + +# --------------------------------------------------------------------------- +# is_likely_chunk +# --------------------------------------------------------------------------- + + +class TestIsLikelyChunk: + def test_short_message(self) -> None: + assert MessageBuffer.is_likely_chunk("hi", 4000) is False + + def test_exactly_threshold(self) -> None: + assert MessageBuffer.is_likely_chunk("a" * 4000, 4000) is True + + def test_above_threshold(self) -> None: + assert MessageBuffer.is_likely_chunk("a" * 4096, 4000) is True + + def test_below_threshold(self) -> None: + assert MessageBuffer.is_likely_chunk("a" * 3999, 4000) is False + + def test_custom_threshold(self) -> None: + assert MessageBuffer.is_likely_chunk("a" * 3000, 3000) is True + + +# --------------------------------------------------------------------------- +# add_chunk — first chunk +# --------------------------------------------------------------------------- + + +class TestAddChunkFirstChunk: + async def test_short_message_returns_result_immediately(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + update = _make_update() + ctx = _make_context() + + result = await buf.add_chunk(KEY, SHORT_TEXT, update, ctx) + + assert result is not None + assert isinstance(result, BufferedResult) + assert result.combined_text == SHORT_TEXT + assert result.chunk_count == 1 + assert result.first_update is update + # No entry in buffer + assert not buf.has_buffer(KEY) + + async def test_long_chunk_returns_none(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + update = _make_update() + ctx = _make_context() + + result = await buf.add_chunk(KEY, LONG_TEXT, update, ctx) + + assert result is None + assert buf.has_buffer(KEY) + + async def test_long_chunk_sends_status_message(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + update = _make_update() + ctx = _make_context() + + await buf.add_chunk(KEY, LONG_TEXT, update, ctx) + + update.message.reply_text.assert_called_once_with("Receiving message\u2026") + + +# --------------------------------------------------------------------------- +# add_chunk — subsequent chunks +# --------------------------------------------------------------------------- + + +class TestAddChunkSubsequent: + async def test_second_long_chunk_stays_buffered(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + u1, u2 = _make_update(), _make_update() + c1, c2 = _make_context(), _make_context() + + await buf.add_chunk(KEY, LONG_TEXT, u1, c1) + result = await buf.add_chunk(KEY, LONG_TEXT, u2, c2) + + assert result is None + assert buf.has_buffer(KEY) + + async def test_short_tail_flushes_immediately(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + u1, u2 = _make_update(), _make_update() + c1, c2 = _make_context(), _make_context() + + await buf.add_chunk(KEY, LONG_TEXT, u1, c1) + result = await buf.add_chunk(KEY, SHORT_TEXT, u2, c2) + + assert result is not None + assert result.combined_text == LONG_TEXT + SHORT_TEXT + assert result.chunk_count == 2 + assert result.first_update is u1 + assert result.last_context is c2 + # Buffer consumed + assert not buf.has_buffer(KEY) + + +# --------------------------------------------------------------------------- +# BufferedResult — concatenation semantics +# --------------------------------------------------------------------------- + + +class TestBufferedResult: + async def test_no_delimiter_between_chunks(self) -> None: + """Telegram splits at char boundaries — join with empty string.""" + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + u1, u2, u3 = _make_update(), _make_update(), _make_update() + c1, c2, c3 = _make_context(), _make_context(), _make_context() + + await buf.add_chunk(KEY, "AAAA" * 1000, u1, c1) + await buf.add_chunk(KEY, "BBBB" * 1000, u2, c2) + result = await buf.add_chunk(KEY, "CC", u3, c3) + + assert result is not None + assert result.combined_text == "AAAA" * 1000 + "BBBB" * 1000 + "CC" + assert result.chunk_count == 3 + + +# --------------------------------------------------------------------------- +# Timer flush +# --------------------------------------------------------------------------- + + +class TestTimerFlush: + async def test_timer_fires_calls_on_flush(self) -> None: + flush_called: list[tuple[BufferKey, BufferedResult]] = [] + + async def _on_flush(key: BufferKey, result: BufferedResult) -> None: + flush_called.append((key, result)) + + buf = MessageBuffer(chunk_timeout=0.1, chunk_threshold=4000, on_flush=_on_flush) + update = _make_update() + ctx = _make_context() + + await buf.add_chunk(KEY, LONG_TEXT, update, ctx) + # Wait for timer + a margin + await asyncio.sleep(0.25) + + assert len(flush_called) == 1 + key, result = flush_called[0] + assert key == KEY + assert result.combined_text == LONG_TEXT + assert result.chunk_count == 1 + assert not buf.has_buffer(KEY) + + async def test_timer_resets_on_new_chunk(self) -> None: + flush_called: list[tuple[BufferKey, BufferedResult]] = [] + + async def _on_flush(key: BufferKey, result: BufferedResult) -> None: + flush_called.append((key, result)) + + buf = MessageBuffer( + chunk_timeout=0.15, chunk_threshold=4000, on_flush=_on_flush + ) + u1, u2 = _make_update(), _make_update() + c1, c2 = _make_context(), _make_context() + + await buf.add_chunk(KEY, LONG_TEXT, u1, c1) + await asyncio.sleep(0.08) # Before first timer fires + await buf.add_chunk(KEY, LONG_TEXT, u2, c2) + await asyncio.sleep(0.08) # First timer would have fired if not reset + assert len(flush_called) == 0 # Still buffered + await asyncio.sleep(0.12) # Now second timer fires + + assert len(flush_called) == 1 + assert flush_called[0][1].combined_text == LONG_TEXT + LONG_TEXT + assert flush_called[0][1].chunk_count == 2 + + +# --------------------------------------------------------------------------- +# cancel +# --------------------------------------------------------------------------- + + +class TestCancel: + async def test_cancel_removes_entry(self) -> None: + buf = MessageBuffer(chunk_timeout=1.0, chunk_threshold=4000) + await buf.add_chunk(KEY, LONG_TEXT, _make_update(), _make_context()) + + result = buf.cancel(KEY) + + assert result is not None + assert result.combined_text == LONG_TEXT + assert not buf.has_buffer(KEY) + + async def test_cancel_nonexistent_key(self) -> None: + buf = MessageBuffer(chunk_timeout=1.0, chunk_threshold=4000) + result = buf.cancel(KEY) + assert result is None + + async def test_cancel_prevents_timer_flush(self) -> None: + flush_called: list[BufferedResult] = [] + + async def _on_flush(_key: BufferKey, result: BufferedResult) -> None: + flush_called.append(result) + + buf = MessageBuffer(chunk_timeout=0.1, chunk_threshold=4000, on_flush=_on_flush) + await buf.add_chunk(KEY, LONG_TEXT, _make_update(), _make_context()) + buf.cancel(KEY) + + await asyncio.sleep(0.2) + assert len(flush_called) == 0 + + +# --------------------------------------------------------------------------- +# Independence — different keys +# --------------------------------------------------------------------------- + + +class TestKeyIndependence: + async def test_different_users_independent(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + key_a: BufferKey = (1, 100, None) + key_b: BufferKey = (2, 100, None) + + await buf.add_chunk(key_a, LONG_TEXT, _make_update(user_id=1), _make_context()) + await buf.add_chunk(key_b, LONG_TEXT, _make_update(user_id=2), _make_context()) + + assert buf.has_buffer(key_a) + assert buf.has_buffer(key_b) + + # Flushing one doesn't affect the other + result = buf.cancel(key_a) + assert result is not None + assert buf.has_buffer(key_b) + + async def test_different_threads_independent(self) -> None: + buf = MessageBuffer(chunk_timeout=0.5, chunk_threshold=4000) + key_a: BufferKey = (1, 100, 10) + key_b: BufferKey = (1, 100, 20) + + await buf.add_chunk(key_a, LONG_TEXT, _make_update(), _make_context()) + await buf.add_chunk(key_b, LONG_TEXT, _make_update(), _make_context()) + + assert buf.has_buffer(key_a) + assert buf.has_buffer(key_b) + + +# --------------------------------------------------------------------------- +# pending_keys +# --------------------------------------------------------------------------- + + +class TestPendingKeys: + async def test_pending_keys_empty_initially(self) -> None: + buf = MessageBuffer() + assert buf.pending_keys == [] + + async def test_pending_keys_after_buffering(self) -> None: + buf = MessageBuffer(chunk_timeout=1.0, chunk_threshold=4000) + await buf.add_chunk(KEY, LONG_TEXT, _make_update(), _make_context()) + assert KEY in buf.pending_keys diff --git a/tests/unit/test_bot/test_middleware.py b/tests/unit/test_bot/test_middleware.py index 4ff58365..69cb8885 100644 --- a/tests/unit/test_bot/test_middleware.py +++ b/tests/unit/test_bot/test_middleware.py @@ -35,6 +35,8 @@ def mock_settings(): settings.enable_api_server = False settings.enable_scheduler = False settings.approved_directory = "/tmp/test" + settings.chunk_buffer_timeout = 0.5 + settings.chunk_buffer_threshold = 4000 return settings