diff --git a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py new file mode 100644 index 0000000..4f8248a --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py @@ -0,0 +1,622 @@ +"""Wrapper around LiteLLM's model I/O library.""" + +from __future__ import annotations + +import json +import logging +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.language_models.llms import create_base_retry_decorator +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, + ToolCall, + ToolCallChunk, + ToolMessage, +) +from langchain_core.messages.ai import UsageMetadata +from langchain_core.outputs import ( + ChatGeneration, + ChatGenerationChunk, + ChatResult, +) +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils import get_from_dict_or_env, pre_init +from langchain_core.utils.function_calling import convert_to_openai_tool +from litellm.types.utils import Delta +from litellm.utils import get_valid_models +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class ChatLiteLLMException(Exception): + """Error with the `LiteLLM I/O` library""" + + +def _create_retry_decorator( + llm: ChatLiteLLM, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions""" + import litellm + + errors = [ + litellm.Timeout, + litellm.APIError, + litellm.APIConnectionError, + litellm.RateLimitError, + ] + return create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + # Fix for azure + # Also OpenAI returns None for tool invocations + content = _dict.get("content", "") or "" + + additional_kwargs = {} + if _dict.get("function_call"): + additional_kwargs["function_call"] = dict(_dict["function_call"]) + + if _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = _dict["tool_calls"] + + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + elif role == "tool": + return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_delta_to_message_chunk( + delta: Union[Delta, Dict[str, Any]], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + # Handle both Delta objects and dicts + if isinstance(delta, dict): + role = delta.get("role") + content = 
delta.get("content") or "" + function_call = delta.get("function_call") + raw_tool_calls = delta.get("tool_calls") + reasoning_content = delta.get("reasoning_content") + else: + role = delta.role + content = delta.content or "" + function_call = delta.function_call + raw_tool_calls = delta.tool_calls + reasoning_content = getattr(delta, "reasoning_content", None) + + if function_call: + additional_kwargs = {"function_call": dict(function_call)} + # The hasattr check is necessary because litellm explicitly deletes the + # `reasoning_content` attribute when it is absent to comply with the OpenAI API. + # This ensures that the code gracefully handles cases where the attribute is + # missing, avoiding potential errors or non-compliance with the API. + elif reasoning_content: + additional_kwargs = {"reasoning_content": reasoning_content} + else: + additional_kwargs = {} + + tool_call_chunks = [] + if raw_tool_calls: + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + ToolCallChunk( + name=rtc["function"]["name"] + if isinstance(rtc, dict) + else rtc.function.name, + args=rtc["function"]["arguments"] + if isinstance(rtc, dict) + else rtc.function.arguments, + id=rtc["id"] if isinstance(rtc, dict) else rtc.id, + index=rtc["index"] if isinstance(rtc, dict) else rtc.index, + ) + for rtc in raw_tool_calls + ] + except KeyError: + pass + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, + ) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "function" or default_class == FunctionMessageChunk: + if isinstance(delta, dict): + func_args = function_call.get("arguments", "") if function_call else "" + func_name = function_call.get("name", "") if function_call else "" + else: + func_args = delta.function_call.arguments if function_call else "" + func_name = delta.function_call.name if function_call else "" + return FunctionMessageChunk(content=func_args, name=func_name) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type] + else: + return default_class(content=content) # type: ignore[call-arg] + + +def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + } + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + message_dict: Dict[str, Any] = {"content": message.content} + if isinstance(message, ChatMessage): + message_dict["role"] = message.role + elif isinstance(message, HumanMessage): + message_dict["role"] = "user" + elif isinstance(message, AIMessage): + message_dict["role"] = "assistant" + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + if message.tool_calls: + message_dict["tool_calls"] = [ + _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls + ] + elif "tool_calls" in message.additional_kwargs: + message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] + elif isinstance(message, SystemMessage): + message_dict["role"] = "system" + elif isinstance(message, FunctionMessage): + message_dict["role"] = "function" + 
message_dict["name"] = message.name + elif isinstance(message, ToolMessage): + message_dict["role"] = "tool" + message_dict["tool_call_id"] = message.tool_call_id + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +_OPENAI_MODELS = get_valid_models(custom_llm_provider="openai") + + +class ChatLiteLLM(BaseChatModel): + """Chat model that uses the LiteLLM API.""" + + client: Any = None #: :meta private: + model: str = "gpt-3.5-turbo" + model_name: Optional[str] = None + stream_options: Optional[Dict[str, Any]] = Field( + default_factory=lambda: {"include_usage": True} + ) + """Model name to use.""" + openai_api_key: Optional[str] = None + azure_api_key: Optional[str] = None + anthropic_api_key: Optional[str] = None + replicate_api_key: Optional[str] = None + cohere_api_key: Optional[str] = None + openrouter_api_key: Optional[str] = None + api_key: Optional[str] = None + streaming: bool = False + api_base: Optional[str] = None + organization: Optional[str] = None + custom_llm_provider: Optional[str] = None + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + temperature: Optional[float] = None + """Run inference with this temperature. Must be in the closed + interval [0.0, 2.0].""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for API call not explicitly specified.""" + top_p: Optional[float] = None + """Decode using nucleus sampling: consider the smallest set of tokens whose + probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" + top_k: Optional[int] = None + """Decode using top-k sampling: consider the set of top_k most probable tokens. + Must be positive.""" + n: Optional[int] = None + """Number of chat completions to generate for each prompt. 
Note that the API may + not return the full n completions if duplicates are generated.""" + max_tokens: Optional[int] = None + + max_retries: int = 1 + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name + return { + "model": set_model_value, + "force_timeout": self.request_timeout, + "max_tokens": self.max_tokens, + "stream": self.streaming, + "n": self.n, + "temperature": self.temperature, + "custom_llm_provider": self.custom_llm_provider, + **self.model_kwargs, + } + + @property + def _client_params(self) -> Dict[str, Any]: + """Get the parameters used for the openai client.""" + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name + self.client.api_base = self.api_base + self.client.api_key = self.api_key + for named_api_key in [ + "openai_api_key", + "azure_api_key", + "anthropic_api_key", + "replicate_api_key", + "cohere_api_key", + "openrouter_api_key", + ]: + if api_key_value := getattr(self, named_api_key): + setattr( + self.client, + named_api_key.replace("_api_key", "_key"), + api_key_value, + ) + self.client.organization = self.organization + creds: Dict[str, Any] = { + "model": set_model_value, + "force_timeout": self.request_timeout, + "api_base": self.api_base, + } + return {**self._default_params, **creds} + + def completion_with_retry( + self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + ) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return self.client.completion(**kwargs) + + return _completion_with_retry(**kwargs) + + + async def acompletion_with_retry( + self, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any + ) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + return await self.client.acompletion(**kwargs) + + return await _completion_with_retry(**kwargs) + + @pre_init + def validate_environment(cls, values: Dict) -> Dict: + """Validate api key, python package exists, temperature, top_p, and top_k.""" + try: + import litellm + except ImportError: + raise ChatLiteLLMException( + "Could not import litellm python package. 
" + "Please install it with `pip install litellm`" + ) + + values["openai_api_key"] = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY", default="" + ) + values["azure_api_key"] = get_from_dict_or_env( + values, "azure_api_key", "AZURE_API_KEY", default="" + ) + values["anthropic_api_key"] = get_from_dict_or_env( + values, "anthropic_api_key", "ANTHROPIC_API_KEY", default="" + ) + values["replicate_api_key"] = get_from_dict_or_env( + values, "replicate_api_key", "REPLICATE_API_KEY", default="" + ) + values["openrouter_api_key"] = get_from_dict_or_env( + values, "openrouter_api_key", "OPENROUTER_API_KEY", default="" + ) + values["cohere_api_key"] = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY", default="" + ) + values["huggingface_api_key"] = get_from_dict_or_env( + values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default="" + ) + values["together_ai_api_key"] = get_from_dict_or_env( + values, "together_ai_api_key", "TOGETHERAI_API_KEY", default="" + ) + values["client"] = litellm + + if values["temperature"] is not None and not 0 <= values["temperature"] <= 2: + raise ValueError("temperature must be in the range [0.0, 2.0]") + + if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: + raise ValueError("top_p must be in the range [0.0, 1.0]") + + if values["top_k"] is not None and values["top_k"] <= 0: + raise ValueError("top_k must be positive") + + return values + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + should_stream = stream if stream is not None else self.streaming + if should_stream: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + response = self.completion_with_retry( + messages=message_dicts, run_manager=run_manager, **params + ) + return self._create_chat_result(response) + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + token_usage = response.get("usage", {}) + for res in response["choices"]: + message = _convert_dict_to_message(res["message"]) + if isinstance(message, AIMessage): + message.response_metadata = { + "model_name": self.model_name or self.model + } + message.usage_metadata = _create_usage_metadata(token_usage) + gen = ChatGeneration( + message=message, + generation_info=dict(finish_reason=res.get("finish_reason")), + ) + generations.append(gen) + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name + llm_output = {"token_usage": token_usage, "model": set_model_value} + return ChatResult(generations=generations, llm_output=llm_output) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = self._client_params + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + message_dicts = [_convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + 
message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + params["stream_options"] = self.stream_options + default_chunk_class = AIMessageChunk + for chunk in self.completion_with_retry( + messages=message_dicts, run_manager=run_manager, **params + ): + usage_metadata = None + if not isinstance(chunk, dict): + chunk = chunk.model_dump() + if "usage" in chunk and chunk["usage"]: + usage_metadata = _create_usage_metadata(chunk["usage"]) + if len(chunk["choices"]) == 0: + continue + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + if usage_metadata and isinstance(chunk, AIMessageChunk): + chunk.usage_metadata = usage_metadata + + default_chunk_class = chunk.__class__ + cg_chunk = ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs, "stream": True} + params["stream_options"] = self.stream_options + default_chunk_class = AIMessageChunk + async for chunk in await self.acompletion_with_retry( + messages=message_dicts, run_manager=run_manager, **params + ): + usage_metadata = None + if not isinstance(chunk, dict): + chunk = chunk.model_dump() + if "usage" in chunk and chunk["usage"]: + usage_metadata = _create_usage_metadata(chunk["usage"]) + if len(chunk["choices"]) == 0: + continue + delta = chunk["choices"][0]["delta"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + if usage_metadata and isinstance(chunk, AIMessageChunk): + chunk.usage_metadata = usage_metadata + default_chunk_class = chunk.__class__ + cg_chunk = ChatGenerationChunk(message=chunk) + if run_manager: + await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + should_stream = stream if stream is not None else self.streaming + if should_stream: + stream_iter = self._astream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + response = await self.acompletion_with_retry( + messages=message_dicts, run_manager=run_manager, **params + ) + return self._create_chat_result(response) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + LiteLLM expects tools argument in OpenAI format. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + tool_choice: Which tool to require the model to call. 
Options are: + - str of the form ``"<>"``: calls <> tool. + - ``"auto"``: + automatically selects a tool (including no tool). + - ``"none"``: + does not call a tool. + - ``"any"`` or ``"required"`` or ``True``: + forces least one tool to be called. + - dict of the form: + ``{"type": "function", "function": {"name": <>}}`` + - ``False`` or ``None``: no effect + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + + # In case of openai if tool_choice is `any` or if bool has been provided we + # change it to `required` as that is supported by openai. + if ( + (self.model is not None and "azure" in self.model) + or (self.model_name is not None and "azure" in self.model_name) + or (self.model is not None and self.model in _OPENAI_MODELS) + or (self.model_name is not None and self.model_name in _OPENAI_MODELS) + ) and (tool_choice == "any" or isinstance(tool_choice, bool)): + tool_choice = "required" + # If tool_choice is bool apart from openai we make it `any` + elif isinstance(tool_choice, bool): + tool_choice = "any" + elif isinstance(tool_choice, dict): + tool_names = [ + formatted_tool["function"]["name"] for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] for tool_name in tool_names + ): + raise ValueError( + f"Tool choice {tool_choice} was specified, but the only " + f"provided tools were {tool_names}." + ) + return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs) + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + set_model_value = self.model + if self.model_name is not None: + set_model_value = self.model_name + return { + "model": set_model_value, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "n": self.n, + } + + @property + def _llm_type(self) -> str: + return "litellm-chat" + + +def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata: + input_tokens = token_usage.get("prompt_tokens", 0) + output_tokens = token_usage.get("completion_tokens", 0) + return UsageMetadata( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + ) diff --git a/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py b/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py index 7a49a89..2f8b6e7 100644 --- a/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py +++ b/jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py @@ -1,14 +1,96 @@ -from typing import Any, Optional - -from jupyterlab_chat.models import Message -from litellm import acompletion +import os +from typing import Any, Callable +import aiosqlite from jupyter_ai_persona_manager import BasePersona, PersonaDefaults -from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME +from jupyter_core.paths import jupyter_data_dir +from jupyterlab_chat.models import Message +from langchain.agents import create_agent +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware +from langchain.agents.middleware.shell_tool import ShellToolMiddleware +from langchain.messages import ToolMessage +from langchain.tools.tool_node import ToolCallRequest +from langchain_core.messages import ToolMessage +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver +from langgraph.types import Command + +from .chat_models import 
ChatLiteLLM from .prompt_template import ( JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) +from .toolkits.notebook import toolkit as nb_toolkit +from .toolkits.jupyterlab import toolkit as jlab_toolkit + +MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite") + + +def format_tool_args_compact(args_dict, threshold=25): + """ + Create a more compact string representation of tool call args. + Each key-value pair is on its own line for better readability. + + Args: + args_dict (dict): Dictionary of tool arguments + threshold (int): Maximum number of lines before truncation (default: 25) + + Returns: + str: Formatted string representation of arguments + """ + if not args_dict: + return "{}" + + formatted_pairs = [] + + for key, value in args_dict.items(): + value_str = str(value) + lines = value_str.split('\n') + + if len(lines) <= threshold: + if len(lines) == 1 and len(value_str) > 80: + # Single long line - truncate + truncated = value_str[:77] + "..." + formatted_pairs.append(f" {key}: {truncated}") + else: + # Add indentation for multi-line values + if len(lines) > 1: + indented_value = '\n '.join([''] + lines) + formatted_pairs.append(f" {key}:{indented_value}") + else: + formatted_pairs.append(f" {key}: {value_str}") + else: + # Truncate and add summary + truncated_lines = lines[:threshold] + remaining_lines = len(lines) - threshold + indented_value = '\n '.join([''] + truncated_lines) + formatted_pairs.append(f" {key}:{indented_value}\n [+{remaining_lines} more lines]") + + return "{\n" + ",\n".join(formatted_pairs) + "\n}" + + +class ToolMonitoringMiddleware(AgentMiddleware): + def __init__(self, *, persona: BasePersona): + self.stream_message = persona.stream_message + self.log = persona.log + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + args = format_tool_args_compact(request.tool_call['args']) + self.log.info(f"{request.tool_call['name']}({args})") + + try: + result = await handler(request) + self.log.info(f"{request.tool_call['name']} Done!") + return result + except Exception as e: + self.log.info(f"{request.tool_call['name']} failed: {e}") + return ToolMessage( + tool_call_id=request.tool_call["id"], status="error", content=f"{e}" + ) class JupyternautPersona(BasePersona): @@ -28,11 +110,45 @@ def defaults(self): system_prompt="...", ) + async def get_memory_store(self): + if not hasattr(self, "_memory_store"): + conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False) + self._memory_store = AsyncSqliteSaver(conn) + return self._memory_store + + def get_tools(self): + tools = nb_toolkit + tools += jlab_toolkit + return nb_toolkit + + async def get_agent(self, model_id: str, model_args, system_prompt: str): + model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True) + memory_store = await self.get_memory_store() + + if not hasattr(self, "search_tool"): + self.search_tool = FilesystemFileSearchMiddleware( + root_path=self.parent.root_dir + ) + if not hasattr(self, "shell_tool"): + self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir) + if not hasattr(self, "tool_call_handler"): + self.tool_call_handler = ToolMonitoringMiddleware( + persona=self + ) + + return create_agent( + model, + system_prompt=system_prompt, + checkpointer=memory_store, + tools=self.get_tools(), # notebook and jlab tools + middleware=[self.shell_tool, self.tool_call_handler], + ) + async def 
process_message(self, message: Message) -> None: - if not hasattr(self, 'config_manager'): + if not hasattr(self, "config_manager"): self.send_message( "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n", - "Please make sure to first install that package in your environment & restart the server." + "Please make sure to first install that package in your environment & restart the server.", ) if not self.config_manager.chat_model: self.send_message( @@ -43,28 +159,34 @@ async def process_message(self, message: Message) -> None: model_id = self.config_manager.chat_model model_args = self.config_manager.chat_model_args - context_as_messages = self.get_context_as_messages(model_id, message) - response_aiter = await acompletion( - **model_args, - model=model_id, - messages=[ - *context_as_messages, - { - "role": "user", - "content": message.body, - }, - ], - stream=True, + system_prompt = self.get_system_prompt(model_id=model_id, message=message) + agent = await self.get_agent( + model_id=model_id, model_args=model_args, system_prompt=system_prompt ) + async def create_aiter(): + async for token, metadata in agent.astream( + {"messages": [{"role": "user", "content": message.body}]}, + {"configurable": {"thread_id": self.ychat.get_id()}}, + stream_mode="messages", + ): + node = metadata["langgraph_node"] + content_blocks = token.content_blocks + if ( + node == "model" + and content_blocks + ): + if token.text: + yield token.text + + response_aiter = create_aiter() await self.stream_message(response_aiter) - def get_context_as_messages( + def get_system_prompt( self, model_id: str, message: Message ) -> list[dict[str, Any]]: """ - Returns the current context, including attachments and recent messages, - as a list of messages accepted by `litellm.acompletion()`. + Returns the system prompt, including attachments as a string. """ system_msg_args = JupyternautSystemPromptArgs( model_id=model_id, @@ -72,36 +194,9 @@ def get_context_as_messages( context=self.process_attachments(message), ).model_dump() - system_msg = { - "role": "system", - "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), - } - - context_as_messages = [system_msg, *self._get_history_as_messages()] - return context_as_messages - - def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: - """ - Returns the current history as a list of messages accepted by - `litellm.acompletion()`. - """ - # TODO: consider bounding history based on message size (e.g. total - # char/token count) instead of message count. - all_messages = self.ychat.get_messages() - - # gather last k * 2 messages and return - # we exclude the last message since that is the human message just - # submitted by a user. 
- start_idx = 0 if k is None else -2 * k - 1 - recent_messages: list[Message] = all_messages[start_idx:-1] - - history: list[dict[str, Any]] = [] - for msg in recent_messages: - role = ( - "assistant" - if msg.sender.startswith("jupyter-ai-personas::") - else "system" if msg.sender == SYSTEM_USERNAME else "user" - ) - history.append({"role": role, "content": msg.body}) + return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args) - return history + def shutdown(self): + if hasattr(self,"_memory_store"): + self.parent.event_loop.create_task(self._memory_store.conn.close()) + super().shutdown() diff --git a/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py b/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py index 05cb7b9..7d655e6 100644 --- a/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py +++ b/jupyter_ai_jupyternaut/jupyternaut/prompt_template.py @@ -30,6 +30,14 @@ - Example of a correct response: `You have \\(\\$80\\) remaining.` +If the user's request involves writing to a file, don't use fenced code blocks, write the content directly. + +If the request requires using the add_cell or edit_cell to add code to a notebook code cell, don't use fenced code block. + +If the request requires adding markdown to a notebook markdown cell, don't use markdown code block. + +Don't echo contents back to user after reading files. Rather use that information to fulfill user's request. + You will receive any provided context and a relevant portion of the chat history. The user's request is located at the last message. Please fulfill the user's request to the best of your ability. diff --git a/jupyter_ai_jupyternaut/jupyternaut/toolkits/__init__.py b/jupyter_ai_jupyternaut/jupyternaut/toolkits/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jupyter_ai_jupyternaut/jupyternaut/toolkits/code_execution.py b/jupyter_ai_jupyternaut/jupyternaut/toolkits/code_execution.py new file mode 100644 index 0000000..5986036 --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/toolkits/code_execution.py @@ -0,0 +1,41 @@ +"""Tools that provide code execution features""" + +import asyncio +import shlex +from typing import Optional + + +async def bash(command: str, timeout: Optional[int] = None) -> str: + """Executes a bash command and returns the result + + Args: + command: The bash command to execute + timeout: Optional timeout in seconds + + Returns: + The command output (stdout and stderr combined) + """ + + proc = await asyncio.create_subprocess_exec( + *shlex.split(command), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout) + output = stdout.decode("utf-8") + error = stderr.decode("utf-8") + + if proc.returncode != 0: + if error: + return f"Error: {error}" + return f"Command failed with exit code {proc.returncode}" + + return output if output else "Command executed successfully with no output." 
+ except asyncio.TimeoutError: + proc.kill() + return f"Command timed out after {timeout} seconds" + + +toolkit = [bash] diff --git a/jupyter_ai_jupyternaut/jupyternaut/toolkits/filesystem.py b/jupyter_ai_jupyternaut/jupyternaut/toolkits/filesystem.py new file mode 100644 index 0000000..c127862 --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/toolkits/filesystem.py @@ -0,0 +1,365 @@ +"""Tools that provide file system functionality""" + +import asyncio +import fnmatch +import glob as glob_module +import os +from typing import List, Optional + +from .utils import normalize_filepath + + +def read(file_path: str, offset: Optional[int] = None, limit: Optional[int] = None) -> str: + """Reads a file from the local filesystem + + Args: + file_path: The absolute path to the file to read + offset: The line number to start reading from (optional) + limit: The number of lines to read (optional) + + Returns: + The contents of the file, potentially with line numbers + """ + try: + file_path = normalize_filepath(file_path) + if not os.path.exists(file_path): + return f"Error: File not found: {file_path}" + + if not os.path.isfile(file_path): + return f"Error: Not a file: {file_path}" + + content = _read_file_content(file_path, offset, limit) + return content + except Exception as e: + return f"Error: Failed to read file: {str(e)}" + + +def _read_file_content( + file_path: str, offset: Optional[int] = None, limit: Optional[int] = None +) -> str: + """Helper function to read file content in a separate thread""" + with open(file_path, "r", encoding="utf-8") as f: + if offset is not None: + # Skip lines until we reach the offset + for _ in range(offset): + line = f.readline() + if not line: + break + + # Read the specified number of lines or all lines if limit is None + if limit is not None: + lines = [f.readline() for _ in range(limit)] + # Filter out None values in case we hit EOF + lines = [line for line in lines if line] + else: + lines = f.readlines() + + # Add line numbers (starting from offset+1 if offset is provided) + start_line = (offset or 0) + 1 + numbered_lines = [f"{i}→{line}" for i, line in enumerate(lines, start=start_line)] + + return "".join(numbered_lines) + + +def write(file_path: str, content: str) -> str: + """Writes content to a file on the local filesystem + + Args: + file_path: + The absolute path to the file to write + content: + The content to write to the file + + Returns: + A success message or error message + """ + try: + file_path = normalize_filepath(file_path) + # Ensure the directory exists + directory = os.path.dirname(file_path) + if directory and not os.path.exists(directory): + os.makedirs(directory) + + _write_file_content(file_path, content) + return f"File written successfully at: {file_path}" + except Exception as e: + return f"Error: Failed to write file: {str(e)}" + + +def _write_file_content(file_path: str, content: str) -> None: + """Helper function to write file content in a separate thread""" + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + +def edit(file_path: str, old_string: str, new_string: str, replace_all: bool = False) -> str: + """Performs string replacement in a file + + Args: + file_path: + The absolute path to the file to modify + old_string: + The text to replace + new_string: + The text to replace it with + replace_all: + Replace all occurrences of old_string (default False) + + Returns: + A success message or error message + """ + try: + file_path = normalize_filepath(file_path) + if not os.path.exists(file_path): + 
return f"Error: File not found: {file_path}" + + if not os.path.isfile(file_path): + return f"Error: Not a file: {file_path}" + + # Read the file content + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Check if old_string exists in the file + if old_string not in content: + return "Error: String to replace not found in file" + + # Perform the replacement + if replace_all: + new_content = content.replace(old_string, new_string) + else: + # Replace only the first occurrence + new_content = content.replace(old_string, new_string, 1) + + # If nothing changed, old and new strings might be identical + if new_content == content: + return "Error: No changes made. Old string and new string might be identical" + + # Write the updated content back to the file + _write_file_content(file_path, new_content) + + return f"File {file_path} has been updated successfully" + except Exception as e: + return f"Error: Failed to edit file: {str(e)}" + + +async def search_and_replace( + file_path: str, pattern: str, replacement: str, replace_all: bool = False +) -> str: + """Performs pattern search and replace in a file. + + Args: + file_path: + The absolute path to the file to modify + pattern: + The pattern to search for (supports sed syntax) + replacement: + The replacement text + replace_all: + Replace all occurrences of pattern (default False) + + Returns: + A success message or error message + """ + try: + file_path = normalize_filepath(file_path) + if not os.path.exists(file_path): + return f"Error: File not found: {file_path}" + + if not os.path.isfile(file_path): + return f"Error: Not a file: {file_path}" + + # Build the sed command + sed_cmd = ["sed"] + + # -i option for in-place editing (macOS requires an extension) + if os.name == "posix" and "darwin" in os.uname().sysname.lower(): + sed_cmd.extend(["-i", ""]) + else: + sed_cmd.append("-i") + + # Add the search and replace expression + expression = f"s/{pattern}/{replacement}/" + if replace_all: + expression += "g" + + sed_cmd.extend([expression, file_path]) + + # Run sed command + proc = await asyncio.create_subprocess_exec( + *sed_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + _, stderr = await proc.communicate() + + if proc.returncode != 0: + if stderr: + error = stderr.decode("utf-8") + return f"Error: sed command failed: {error}" + return f"Error: sed command failed with return code {proc.returncode}" + + return f"File {file_path} has been updated successfully" + except Exception as e: + return f"Error: Failed to search and edit file: {str(e)}" + + +async def glob(pattern: str, path: Optional[str] = None) -> str: + """Searches for files that matches the glob pattern + + Args: + pattern: + The glob pattern to match files against + path: + The directory to search in (optional, defaults to current directory) + + Returns: + A list of matching file paths sorted by modification time + """ + try: + search_path = normalize_filepath(path) if path else os.getcwd() + if not os.path.exists(search_path): + return f"Error: Path not found: {search_path}" + + # Use asyncio.to_thread to run glob in a separate thread + matching_files = await asyncio.to_thread(_glob_search, search_path, pattern) + + if not matching_files: + return "No matching files found" + + # Sort files by modification time (most recent first) + matching_files.sort(key=lambda f: os.path.getmtime(f), reverse=True) + matching_files = [str(f) for f in matching_files] + + return "\n".join(matching_files) + except Exception as e: + return 
f"Error: Failed to perform glob search: {str(e)}" + + +def _glob_search(search_path: str, pattern: str) -> List[str]: + """Helper function to perform glob search in a separate thread""" + # Construct the full pattern + if not search_path.endswith(os.sep) and not pattern.startswith(os.sep): + full_pattern = os.path.join(search_path, pattern) + else: + full_pattern = search_path + pattern + + # Use glob.glob for the actual search + return glob_module.glob(full_pattern, recursive=True) + + +async def grep( + pattern: str, include: Optional[str] = None, path: Optional[str] = None +) -> List[str]: + """Fast content search using regular expressions + + Args: + pattern: + The regular expression pattern to search for in file contents + include: + File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}") (optional) + path: + The directory to search in (optional, defaults to current directory) + + Returns: + A list of file paths with at least one match + """ + try: + search_path = normalize_filepath(path) if path else os.getcwd() + if not os.path.exists(search_path): + return [f"Error: Path not found: {search_path}"] + + # Prepare the command arguments for running grep + command_args = ["grep", "-l", "--include", include or "*", "-r", pattern, search_path] + + # Run grep command asynchronously + proc = await asyncio.create_subprocess_exec( + *command_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await proc.communicate() + + if ( + proc.returncode != 0 and proc.returncode != 1 + ): # 1 means no matches, which is not an error + if stderr: + error = stderr.decode("utf-8") + return [f"Error: Grep command failed: {error}"] + return [f"Error: Grep command failed with return code {proc.returncode}"] + + # Parse the output and get the list of files + matching_files = stdout.decode("utf-8").strip().split("\n") + + # Filter out empty entries + matching_files = [f for f in matching_files if f] + + if not matching_files: + return [] + + # Sort files by modification time (most recent first) + matching_files.sort(key=lambda f: os.path.getmtime(f), reverse=True) + + return matching_files + except Exception as e: + return [f"Error: Failed to perform grep search: {str(e)}"] + + +async def ls(path: str, ignore: Optional[List[str]] = None) -> str: + """Lists files and directories in a given path + + Args: + path: + The absolute path to the directory to list + ignore: + List of glob patterns to ignore (optional) + + Returns: + A list of files and directories in the given path + """ + try: + path = normalize_filepath(path) + if not os.path.exists(path): + return f"Error: Path not found: {path}" + + if not os.path.isdir(path): + return f"Error: Not a directory: {path}" + + # Get all files and directories in the given path + items = await asyncio.to_thread(os.listdir, path) + + # Apply ignore patterns if provided + if ignore: + filtered_items = [] + for item in items: + item_path = os.path.join(path, item) + should_ignore = False + + for pattern in ignore: + if fnmatch.fnmatch(item, pattern) or fnmatch.fnmatch(item_path, pattern): + should_ignore = True + break + + if not should_ignore: + filtered_items.append(item) + + items = filtered_items + + # Construct full paths + full_paths = [os.path.join(path, item) for item in items] + + # Sort by type (directories first) and then by name + full_paths.sort(key=lambda p: (0 if os.path.isdir(p) else 1, p.lower())) + + return "\n".join(full_paths) + except Exception as e: + return f"Error: Failed to list directory: {str(e)}" 
+ + +toolkit = [ + read, + edit, + write, + search_and_replace, + glob, + grep, + ls +] \ No newline at end of file diff --git a/jupyter_ai_jupyternaut/jupyternaut/toolkits/jupyterlab.py b/jupyter_ai_jupyternaut/jupyternaut/toolkits/jupyterlab.py new file mode 100644 index 0000000..7c782fa --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/toolkits/jupyterlab.py @@ -0,0 +1,16 @@ +from jupyterlab_commands_toolkit.tools import execute_command + + +async def open_file(file_path: str): + """ + Opens a file in JupyterLab main area + """ + await execute_command("docmanager:open", {"path": file_path}) + +async def run_all_cells(): + """ + Runs all cells in the currently active Jupyter notebook + """ + return await execute_command("notebook:run-all-cells") + +toolkit = [open_file, run_all_cells] \ No newline at end of file diff --git a/jupyter_ai_jupyternaut/jupyternaut/toolkits/notebook.py b/jupyter_ai_jupyternaut/jupyternaut/toolkits/notebook.py new file mode 100644 index 0000000..5edc234 --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/toolkits/notebook.py @@ -0,0 +1,1177 @@ +import asyncio +import difflib +import json +import os +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import nbformat +from jupyter_ydoc import YNotebook +from pycrdt import Assoc, Text + +from .utils import ( + cell_to_md, + get_file_id, + get_jupyter_ydoc, + normalize_filepath, + notebook_json_to_md, +) + + +def _is_uuid_like(value: str) -> bool: + """Check if a string looks like a UUID v4""" + if not isinstance(value, str): + return False + # UUID v4 pattern: 8-4-4-4-12 hexadecimal characters + uuid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" + return bool(re.match(uuid_pattern, value, re.IGNORECASE)) + + +def _is_index_like(value: str) -> bool: + """Check if a string looks like a numeric index""" + if not isinstance(value, str): + return False + try: + int(value) + return True + except ValueError: + return False + + +async def _resolve_cell_id(file_path: str, cell_id_or_index: str) -> str: + """ + Resolve a cell_id parameter that might be either a UUID or an index. + If it's an index, convert it to the actual cell_id. + """ + if _is_uuid_like(cell_id_or_index): + return cell_id_or_index + elif _is_index_like(cell_id_or_index): + index = int(cell_id_or_index) + try: + actual_cell_id = await get_cell_id_from_index(file_path, index) + return actual_cell_id + except Exception as e: + raise ValueError(f"Invalid cell index {index}: {str(e)}") + else: + # Assume it's a cell_id and let the downstream function handle validation + return cell_id_or_index + + +async def read_notebook(file_path: str, include_outputs=False) -> str: + """Returns the complete notebook content as markdown string. + + This function reads a Jupyter notebook file and converts its content to a markdown string. + It uses the read_notebook_json function to read the notebook file and then converts + the resulting JSON to markdown. + + Args: + file_path: + The relative path to the notebook file on the filesystem. + include_outputs: + If True, cell outputs will be included in the markdown. Default is False. + + Returns: + The notebook content as a markdown string. 
+ """ + try: + file_path = normalize_filepath(file_path) + notebook_dict = await read_notebook_json(file_path) + notebook_md = notebook_json_to_md(notebook_dict, include_outputs=include_outputs) + return notebook_md + except Exception: + raise + +def clean_text(text: Union[str, list, None]) -> Optional[str]: + """ + Clean and format text output (equivalent to kC0). + + Args: + text: Text data that might be string, list, or None + + Returns: + Cleaned text string or None + """ + if text is None: + return None + + if isinstance(text, list): + return ''.join(str(item) for item in text) + + return str(text) + +def process_notebook_output(output_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Process a Jupyter notebook cell output into a standardized format. + + Args: + output_data: Raw output data from notebook cell + + Returns: + Processed output dictionary with standardized format + """ + output_type = output_data.get('output_type') + + if output_type == "stream": + return { + 'output_type': output_type, + 'text': clean_text(output_data.get('text', '')) + } + + elif output_type in ["execute_result", "display_data"]: + data = output_data.get('data', {}) + return { + 'output_type': output_type, + 'text': clean_text(data.get('text/plain')), + 'image': extract_image_data(data) if data else None, + } + + elif output_type == "error": + error_name = output_data.get('ename', '') + error_value = output_data.get('evalue', '') + traceback = output_data.get('traceback', []) + + error_text = f"{error_name}: {error_value}\n{chr(10).join(traceback)}" + + return { + 'output_type': output_type, + 'text': clean_text(error_text), + } + + # Handle unknown output types + return output_data + +def extract_image_data(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Extract image data from notebook output data (equivalent to qh6). + + Args: + data: Output data dictionary that may contain various MIME types + + Returns: + Extracted image data or None + """ + # Common image MIME types in Jupyter notebooks + image_mime_types = [ + 'image/png', + 'image/jpeg', + 'image/jpg', + 'image/gif', + 'image/svg+xml' + ] + + for mime_type in image_mime_types: + if mime_type in data: + return { + 'mime_type': mime_type, + 'data': data[mime_type] + } + + return None + +def format_notebook_cell(cell_data: Dict[str, Any], cell_index: int, language: str, include_full_outputs: bool = False) -> Dict[str, Any]: + """ + Format a Jupyter notebook cell into a standardized format. 
+ + Args: + cell_data: Raw cell data from notebook JSON + cell_index: Index of the cell in the notebook + language: Programming language of the notebook + include_full_outputs: Whether to include full outputs or truncate large ones + """ + cell_id = cell_data.get('id', f'cell-{cell_index}') + + formatted_cell = { + 'cellType': cell_data['cell_type'], + 'source': ''.join(cell_data['source']) if isinstance(cell_data['source'], list) else cell_data['source'], + 'execution_count': cell_data.get('execution_count') if cell_data['cell_type'] == 'code' else None, + 'cell_id': cell_id, + } + + # Add language for code cells + if cell_data['cell_type'] == 'code': + formatted_cell['language'] = language + + # Handle outputs for code cells + if cell_data['cell_type'] == 'code' and cell_data.get('outputs'): + processed_outputs = [process_notebook_output(output) for output in cell_data['outputs']] + + # Truncate large outputs unless specifically requested to include full outputs + if not include_full_outputs and len(json.dumps(processed_outputs)) > 10000: + formatted_cell['outputs'] = [ + { + 'output_type': 'stream', + 'text': f'Outputs are too large to include. Use command with: cat | jq \'.cells[{cell_index}].outputs\'', + } + ] + else: + formatted_cell['outputs'] = processed_outputs + + return formatted_cell + +async def read_notebook_cells(notebook_path: str, specific_cell_id: Optional[str] = None) -> List[Dict[str, Any]]: + """ + Read and process cells from a Jupyter notebook file. + + Args: + notebook_path: Path to the notebook file + specific_cell_id: Optional cell ID to return only that cell + + Returns: + List of formatted cell dictionaries + + Raises: + FileNotFoundError: If notebook file doesn't exist + ValueError: If specific cell ID is not found + """ + resolved_path = normalize_filepath(notebook_path) + + with open(resolved_path, 'r', encoding='utf-8') as file: + notebook_data = json.load(file) + + language = notebook_data.get('metadata', {}).get('language_info', {}).get('name', 'python') + + # If requesting a specific cell + if specific_cell_id: + target_cell = None + cell_index = -1 + + for i, cell in enumerate(notebook_data['cells']): + if cell.get('id') == specific_cell_id: + target_cell = cell + cell_index = i + break + + if target_cell is None: + raise ValueError(f'Cell with ID "{specific_cell_id}" not found in notebook') + + return [format_notebook_cell(target_cell, cell_index, language, include_full_outputs=True)] + + # Return all cells + return [ + format_notebook_cell(cell, index, language, include_full_outputs=False) + for index, cell in enumerate(notebook_data['cells']) + ] + + +async def read_notebook_json(file_path: str) -> Dict[str, Any]: + """Returns the complete notebook content as a JSON dictionary. + + This function reads a Jupyter notebook file and returns its content as a + dictionary representation of the JSON structure. + + Args: + file_path: + The relative path to the notebook file on the filesystem. + + Returns: + A dictionary containing the complete notebook structure. + """ + try: + file_path = normalize_filepath(file_path) + with open(file_path, "r", encoding="utf-8") as f: + notebook_dict = json.load(f) + return notebook_dict + except Exception: + raise + + +async def read_cell(file_path: str, cell_id: str, include_outputs: bool = True) -> str: + """Returns the notebook cell as a markdown string. + + This function reads a specific cell from a Jupyter notebook file and converts + it to a markdown string. 
It uses the read_cell_json function to read the cell + and then converts it to markdown. + + Args: + file_path: + The relative path to the notebook file on the filesystem. + cell_id: + The UUID of the cell to read, or a numeric index as string. + include_outputs: + If True, cell outputs will be included in the markdown. Default is True. + + Returns: + The cell content as a markdown string. + + Raises: + LookupError: If no cell with the given ID is found. + """ + try: + file_path = normalize_filepath(file_path) + # Resolve cell_id in case it's an index + resolved_cell_id = await _resolve_cell_id(file_path, cell_id) + cell, cell_index = await read_cell_json(file_path, resolved_cell_id) + cell_md = cell_to_md(cell, cell_index) + return cell_md + except Exception: + raise + + +async def read_cell_json(file_path: str, cell_id: str) -> Tuple[Dict[str, Any], int]: + """Returns the notebook cell as a JSON dictionary and its index. + + This function reads a specific cell from a Jupyter notebook file and returns + both the cell content as a dictionary and the cell's index within the notebook. + + Args: + file_path: + The relative path to the notebook file on the filesystem. + cell_id: + The UUID of the cell to read, or a numeric index as string. + + Returns: + A tuple containing: + - The cell as a dictionary + - The index of the cell in the notebook + + Raises: + LookupError: If no cell with the given ID is found. + """ + try: + file_path = normalize_filepath(file_path) + # Resolve cell_id in case it's an index + resolved_cell_id = await _resolve_cell_id(file_path, cell_id) + notebook_json = await read_notebook_json(file_path) + cell_index = _get_cell_index_from_id_json(notebook_json, resolved_cell_id) + + if cell_index is not None and 0 <= cell_index < len(notebook_json["cells"]): + cell = notebook_json["cells"][cell_index] + return cell, cell_index + + raise LookupError(f"No cell found with {cell_id=}") + + except Exception: + raise + + +async def get_cell_id_from_index(file_path: str, cell_index: int) -> str: + """Finds the cell_id of the cell at a specific cell index. + + This function reads a Jupyter notebook file and returns the UUID of the cell + at the specified index position. + + Args: + file_path: + The relative path to the notebook file on the filesystem. + cell_index: + The index of the cell to find the ID for. + + Returns: + The UUID of the cell at the specified index, or None if the index is out of range + or if the cell does not have an ID. + """ + try: + file_path = normalize_filepath(file_path) + cell_id = None + notebook_json = await read_notebook_json(file_path) + cells = notebook_json["cells"] + + if 0 <= cell_index < len(cells): + cell_id = cells[cell_index].get("id") + else: + cell_id = None + + if cell_id is None: + raise ValueError("No cell_id found, use `insert_cell` based on cell index") + + return cell_id + + except Exception: + raise + + +async def add_cell( + file_path: str, + content: str | None = None, + cell_id: str | None = None, + add_above: bool = False, + cell_type: Literal["code", "markdown", "raw"] = "code", +): + """Adds a new cell to the Jupyter notebook above or below a specified cell. + + This function adds a new cell to a Jupyter notebook. It first attempts to use + the in-memory YDoc representation if the notebook is currently active. If the + notebook is not active, it falls back to using the filesystem to read, modify, + and write the notebook file directly. + + Args: + file_path: + The relative path to the notebook file on the filesystem. 
+ content: + The content of the new cell. If None, an empty cell is created. + cell_id: + The UUID of the cell to add relative to, or a numeric index as string. If None, + the cell is added at the end of the notebook. + add_above: + If True, the cell is added above the specified cell. If False, + it's added below the specified cell. + cell_type: + The type of cell to add ("code", "markdown", "raw"). + + Returns: + None + """ + try: + file_path = normalize_filepath(file_path) + # Resolve cell_id in case it's an index + resolved_cell_id = await _resolve_cell_id(file_path, cell_id) if cell_id else None + + file_id = await get_file_id(file_path) + ydoc: YNotebook = await get_jupyter_ydoc(file_id) + + if ydoc: + cells_count = ydoc.cell_number + cell_index = ( + _get_cell_index_from_id_ydoc(ydoc, resolved_cell_id) if resolved_cell_id else None + ) + insert_index = _determine_insert_index(cells_count, cell_index, add_above) + + cell = { + "cell_type": cell_type, + "source": "", + } + ycell = ydoc.create_ycell(cell) + if insert_index >= cells_count: + ydoc.ycells.append(ycell) + else: + ydoc.ycells.insert(insert_index, ycell) + await write_to_cell_collaboratively(ydoc, ycell, content or "") + else: + with open(file_path, "r", encoding="utf-8") as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cells_count = len(notebook.cells) + cell_index = ( + _get_cell_index_from_id_nbformat(notebook, resolved_cell_id) + if resolved_cell_id + else None + ) + insert_index = _determine_insert_index(cells_count, cell_index, add_above) + + if cell_type == "code": + notebook.cells.insert(insert_index, nbformat.v4.new_code_cell(source=content or "")) + elif cell_type == "markdown": + notebook.cells.insert( + insert_index, nbformat.v4.new_markdown_cell(source=content or "") + ) + else: + notebook.cells.insert(insert_index, nbformat.v4.new_raw_cell(source=content or "")) + + with open(file_path, "w", encoding="utf-8") as f: + nbformat.write(notebook, f) + + except Exception: + raise + + +async def insert_cell( + file_path: str, + insert_index: int, + content: str | None = None, + cell_type: Literal["code", "markdown", "raw"] = "code", +): + """Inserts a new cell to the Jupyter notebook at the specified cell index. + + Args: + file_path: + The relative path to the notebook file on the filesystem. + insert_index: + The index to insert the cell at. + content: + The content of the new cell. If None, an empty cell is created. + cell_type: + The type of cell to add ("code", "markdown", "raw"). 
+ + Returns: + None + """ + try: + file_path = normalize_filepath(file_path) + file_id = await get_file_id(file_path) + ydoc = await get_jupyter_ydoc(file_id) + + if ydoc: + cells_count = ydoc.cell_number + + cell = { + "cell_type": cell_type, + "source": "", + } + ycell = ydoc.create_ycell(cell) + if insert_index >= cells_count: + ydoc.ycells.append(ycell) + else: + ydoc.ycells.insert(insert_index, ycell) + await write_to_cell_collaboratively(ydoc, ycell, content or "") + else: + with open(file_path, "r", encoding="utf-8") as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cells_count = len(notebook.cells) + + if cell_type == "code": + notebook.cells.insert(insert_index, nbformat.v4.new_code_cell(source=content or "")) + elif cell_type == "markdown": + notebook.cells.insert( + insert_index, nbformat.v4.new_markdown_cell(source=content or "") + ) + else: + notebook.cells.insert(insert_index, nbformat.v4.new_raw_cell(source=content or "")) + + with open(file_path, "w", encoding="utf-8") as f: + nbformat.write(notebook, f) + + except Exception: + raise + + +async def delete_cell(file_path: str, cell_id: str): + """Removes a notebook cell with the specified cell ID. + + This function deletes a cell from a Jupyter notebook. It first attempts to use + the in-memory YDoc representation if the notebook is currently active. If the + notebook is not active, it falls back to using the filesystem to read, modify, + and write the notebook file directly using nbformat. + + Args: + file_path: The relative path to the notebook file on the filesystem. + cell_id: The UUID of the cell to delete, or a numeric index as string. + + Returns: + None + """ + try: + file_path = normalize_filepath(file_path) + # Resolve cell_id in case it's an index + resolved_cell_id = await _resolve_cell_id(file_path, cell_id) + + file_id = await get_file_id(file_path) + ydoc = await get_jupyter_ydoc(file_id) + + if ydoc: + cell_index = _get_cell_index_from_id_ydoc(ydoc, resolved_cell_id) + if cell_index is not None and 0 <= cell_index < len(ydoc.ycells): + del ydoc.ycells[cell_index] + else: + pass # Cell not found in ydoc + else: + with open(file_path, "r", encoding="utf-8") as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cell_index = _get_cell_index_from_id_nbformat(notebook, resolved_cell_id) + if cell_index is not None and 0 <= cell_index < len(notebook.cells): + notebook.cells.pop(cell_index) + with open(file_path, "w", encoding="utf-8") as f: + nbformat.write(notebook, f) + else: + pass # Cell not found in notebook + + if cell_index is None: + raise ValueError(f"Could not find cell index for {cell_id=}") + + except Exception: + raise + + +def get_cursor_details( + cell_source: Text, start_index: int, stop_index: Optional[int] = None +) -> Dict[str, Any]: + """ + Creates cursor details for collaborative notebook cursor positioning. + + This function constructs the cursor details object required by the YNotebook + awareness system to show cursor positions in collaborative editing environments. + It handles both single cursor positions and text selections. 
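+    When stop_index is omitted, the anchor mirrors the head and the selection is
+    reported as empty.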
+ + Args: + cell_source: The YText source object representing the cell content + start_index: The starting position of the cursor (0-based index) + stop_index: The ending position for selections (optional) + + Returns: + dict: Cursor details object with head, anchor, and selection state + + Example: + >>> details = get_cursor_details(cell_source, 10) # Single cursor at position 10 + >>> details = get_cursor_details(cell_source, 5, 15) # Selection from 5 to 15 + """ + # Create sticky index for the head position (where cursor starts) + head_sticky_index = cell_source.sticky_index(start_index, Assoc.BEFORE) + head_sticky_index_data = head_sticky_index.to_json() + + # Initialize cursor details with default values + cursor_details: Dict[str, Any] = {"primary": True, "empty": True} + + # Set the head position (where cursor starts) + cursor_details["head"] = { + "type": head_sticky_index_data["item"], + "tname": None, + "item": head_sticky_index_data["item"], + "assoc": 0, + } + + # By default, anchor is same as head (no selection) + cursor_details["anchor"] = cursor_details["head"] + + # If stop_index is provided, create a selection + if stop_index is not None: + anchor_sticky_index = cell_source.sticky_index(stop_index, Assoc.BEFORE) + anchor_sticky_index_data = anchor_sticky_index.to_json() + cursor_details["anchor"] = { + "type": anchor_sticky_index_data["item"], + "tname": None, + "item": anchor_sticky_index_data["item"], + "assoc": 0, + } + cursor_details["empty"] = False # Not empty when there's a selection + + return cursor_details + + +def set_cursor_in_ynotebook( + ynotebook: YNotebook, cell_source: Text, start_index: int, stop_index: Optional[int] = None +) -> None: + """ + Sets the cursor position in a collaborative notebook environment. + + This function updates the cursor position in the YNotebook awareness system, + which allows other collaborators to see where the cursor is positioned. + It handles both single cursor positions and text selections. + + Args: + ynotebook: The YNotebook instance representing the collaborative notebook + cell_source: The YText source object representing the cell content + start_index: The starting position of the cursor (0-based index) + stop_index: The ending position for selections (optional) + + Returns: + None: This function does not return a value + + Note: + This function silently ignores any errors that occur during cursor setting + to avoid breaking the main collaborative editing operations. + + Example: + >>> set_cursor_in_ynotebook(ynotebook, cell_source, 10) # Set cursor at position 10 + >>> set_cursor_in_ynotebook(ynotebook, cell_source, 5, 15) # Select text from 5 to 15 + """ + try: + # Get cursor details for the specified position/selection + details = get_cursor_details(cell_source, start_index, stop_index=stop_index) + + # Update the awareness system with the cursor position + if ynotebook.awareness: + ynotebook.awareness.set_local_state_field("cursors", [details]) + except Exception: + # Silently ignore cursor setting errors to avoid breaking main operations + # This is intentional - cursor positioning is a visual enhancement, not critical + pass + + +async def write_to_cell_collaboratively( + ynotebook, ycell, content: str, typing_speed: float = 0.1 +) -> bool: + """ + Writes content to a Jupyter notebook cell with collaborative typing simulation. + + This function provides a collaborative writing experience by applying text changes + incrementally with visual feedback. 
It uses a diff-based approach to compute the + minimal set of changes needed and applies them with cursor positioning and timing + delays to simulate natural typing behavior. + + The function handles three types of operations: + - Delete: Removes text with visual highlighting + - Insert: Adds text word-by-word with typing delays + - Replace: Combines delete and insert operations + + Args: + ynotebook: The YNotebook instance representing the collaborative notebook + ycell: The YCell instance representing the specific cell to modify + content: The new content to write to the cell + typing_speed: Delay in seconds between typing operations (default: 0.1) + + Returns: + bool: True if the operation completed successfully + + Raises: + ValueError: If ynotebook/ycell is None or typing_speed is negative + TypeError: If content is not a string + RuntimeError: If cell content extraction or writing fails + + Example: + >>> # Write with default typing speed + >>> success = await write_to_cell_collaboratively(ynotebook, ycell, "print('Hello')") + >>> + >>> # Write with custom typing speed (faster) + >>> success = await write_to_cell_collaboratively( + ... ynotebook, ycell, "print('World')", typing_speed=0.05 + ... ) + """ + # Input validation + if ynotebook is None: + raise ValueError("ynotebook cannot be None") + if ycell is None: + raise ValueError("ycell cannot be None") + if not isinstance(content, str): + raise TypeError("content must be a string") + if typing_speed < 0: + raise ValueError("typing_speed must be non-negative") + + try: + # Extract current cell content + cell = ycell.to_py() + old_content = cell.get("source", "") + cell_source = ycell["source"] # YText object for collaborative editing + new_content = content + + # Early return if content is unchanged + if old_content == new_content: + return True + + except Exception as e: + raise RuntimeError(f"Failed to extract cell content: {e}") + + try: + # Compute the minimal set of changes needed using difflib + sequence_matcher = difflib.SequenceMatcher(None, old_content, new_content) + cursor_position = 0 + + # Set initial cursor position + _safe_set_cursor(ynotebook, cell_source, cursor_position) + + # Apply each change operation sequentially + for operation, old_start, old_end, new_start, new_end in sequence_matcher.get_opcodes(): + if operation == "equal": + # No changes needed for this segment, just advance cursor + cursor_position += old_end - old_start + + elif operation == "delete": + # Remove text with visual feedback + delete_length = old_end - old_start + await _handle_delete_operation( + ynotebook, cell_source, cursor_position, delete_length, typing_speed + ) + # Cursor stays at same position after deletion + + elif operation == "insert": + # Add text with typing simulation + cursor_position = await _handle_insert_operation( + ynotebook, + cell_source, + cursor_position, + new_content, + new_start, + new_end, + typing_speed, + ) + + elif operation == "replace": + # Combine delete and insert operations + delete_length = old_end - old_start + cursor_position = await _handle_replace_operation( + ynotebook, + cell_source, + cursor_position, + new_content, + delete_length, + new_start, + new_end, + typing_speed, + ) + + # Set final cursor position at the end of the content + _safe_set_cursor(ynotebook, cell_source, cursor_position) + + return True + + except Exception as e: + raise RuntimeError(f"Failed to write cell content collaboratively: {e}") + + +async def _handle_delete_operation( + ynotebook, cell_source, cursor_position: int, 
delete_length: int, typing_speed: float +) -> None: + """ + Handle deletion of text chunks with visual feedback. + + This function provides visual feedback during deletion by first highlighting + the text to be deleted, then removing it after a delay to simulate natural + deletion behavior in collaborative environments. + + Args: + ynotebook: The YNotebook instance for cursor positioning + cell_source: The YText source object representing the cell content + cursor_position: Current cursor position in the text + delete_length: Number of characters to delete from cursor position + typing_speed: Base delay between operations in seconds + + Returns: + None + """ + # Highlight the text chunk that will be deleted (visual feedback) + _safe_set_cursor(ynotebook, cell_source, cursor_position, cursor_position + delete_length) + await asyncio.sleep(min(0.3, typing_speed * 3)) # Cap highlight duration at 0.3s + + # Perform the actual deletion + del cell_source[cursor_position : cursor_position + delete_length] + await asyncio.sleep(typing_speed) + + +async def _handle_insert_operation( + ynotebook, + cell_source, + cursor_position: int, + new_content: str, + new_start: int, + new_end: int, + typing_speed: float, +) -> int: + """ + Handle insertion of text with word-by-word typing simulation. + + This function simulates natural typing behavior by inserting text word-by-word + with appropriate delays and cursor positioning. It handles both regular text + and whitespace-only content appropriately. + + Args: + ynotebook: The YNotebook instance for cursor positioning + cell_source: The YText source object representing the cell content + cursor_position: Current cursor position in the text + new_content: The complete new content string + new_start: Start index of text to insert in the new content + new_end: End index of text to insert in the new content + typing_speed: Base delay between typing operations in seconds + + Returns: + int: The new cursor position after insertion + """ + text_to_insert = new_content[new_start:new_end] + words = text_to_insert.split() + + # Handle whitespace-only or empty insertions + if not words or text_to_insert.strip() == "": + cell_source.insert(cursor_position, text_to_insert) + cursor_position += len(text_to_insert) + _safe_set_cursor(ynotebook, cell_source, cursor_position) + await asyncio.sleep(typing_speed) + return cursor_position + + # Insert text word-by-word with proper spacing and punctuation + current_pos = 0 + for word in words: + # Find the position of this word in the text + word_start = text_to_insert.find(word, current_pos) + + # Insert any whitespace or punctuation before the word + if word_start > current_pos: + prefix = text_to_insert[current_pos:word_start] + cell_source.insert(cursor_position, prefix) + cursor_position += len(prefix) + + # Insert the word itself + cell_source.insert(cursor_position, word) + cursor_position += len(word) + current_pos = word_start + len(word) + + # Update cursor position and pause for typing effect + _safe_set_cursor(ynotebook, cell_source, cursor_position) + await asyncio.sleep(typing_speed) + + # Insert any remaining text after the last word (punctuation, etc.) 
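+    # str.split() discards surrounding whitespace, so a trailing newline or
+    # trailing spaces would be lost without this final suffix insertion.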
+    if current_pos < len(text_to_insert):
+        suffix = text_to_insert[current_pos:]
+        cell_source.insert(cursor_position, suffix)
+        cursor_position += len(suffix)
+        _safe_set_cursor(ynotebook, cell_source, cursor_position)
+
+    return cursor_position
+
+
+async def _handle_replace_operation(
+    ynotebook,
+    cell_source,
+    cursor_position: int,
+    new_content: str,
+    delete_length: int,
+    new_start: int,
+    new_end: int,
+    typing_speed: float,
+) -> int:
+    """
+    Handle replacement operations by deleting then inserting.
+
+    This function simulates natural text replacement behavior by first deleting
+    the old text (with visual feedback) and then inserting the new text with
+    typing simulation. A pause is added between operations to make the replacement
+    feel more natural.
+
+    Args:
+        ynotebook: The YNotebook instance for cursor positioning
+        cell_source: The YText source object representing the cell content
+        cursor_position: Current cursor position in the text
+        new_content: The complete new content string
+        delete_length: Number of characters to delete from cursor position
+        new_start: Start index of replacement text in the new content
+        new_end: End index of replacement text in the new content
+        typing_speed: Base delay between typing operations in seconds
+
+    Returns:
+        int: The new cursor position after replacement
+    """
+    # First, delete the old text with visual feedback
+    await _handle_delete_operation(
+        ynotebook, cell_source, cursor_position, delete_length, typing_speed
+    )
+
+    # Brief pause between deletion and insertion for natural feel
+    await asyncio.sleep(typing_speed * 2)
+
+    # Then, insert the new text with typing simulation
+    cursor_position = await _handle_insert_operation(
+        ynotebook, cell_source, cursor_position, new_content, new_start, new_end, typing_speed
+    )
+
+    return cursor_position
+
+
+def _safe_set_cursor(
+    ynotebook: YNotebook, cell_source: Text, cursor_position: int, stop_cursor: Optional[int] = None
+) -> None:
+    """
+    Safely set cursor position with error handling.
+
+    This function wraps the cursor positioning logic to prevent errors from
+    breaking the main collaborative writing operations. Since cursor positioning
+    is a visual enhancement rather than core functionality, errors are silently
+    ignored to maintain robustness.
+
+    Args:
+        ynotebook: The YNotebook instance for cursor positioning
+        cell_source: The YText source object representing the cell content
+        cursor_position: The cursor position to set
+        stop_cursor: Optional end position for text selections
+
+    Returns:
+        None
+
+    Note:
+        This function silently ignores all exceptions to prevent cursor
+        positioning errors from interfering with the main editing operations.
+    """
+    try:
+        set_cursor_in_ynotebook(ynotebook, cell_source, cursor_position, stop_cursor)
+    except Exception:
+        # Silently ignore cursor setting errors to avoid breaking the main operation
+        # Cursor positioning is a visual enhancement, not critical functionality
+        pass
+
+
+async def edit_cell(file_path: str, cell_id: str, content: str) -> None:
+    """Edits the content of a notebook cell with the specified ID.
+
+    This function modifies the content of a cell in a Jupyter notebook.
+
+    Args:
+        file_path:
+            The relative path to the notebook file on the filesystem.
+        cell_id:
+            The UUID of the cell to edit, or a numeric index as string.
+        content:
+            The new content for the cell; the previous source is replaced entirely.
+
+    Returns:
+        None
+
+    Raises:
+        ValueError: If the cell_id is not found in the notebook.
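+
+    Example:
+        >>> # Illustrative path and cell ID
+        >>> await edit_cell("notebooks/demo.ipynb", "a1b2c3d4", "print('updated')")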
+    """
+    try:
+        file_path = normalize_filepath(file_path)
+        # Resolve cell_id in case it's an index
+        resolved_cell_id = await _resolve_cell_id(file_path, cell_id)
+
+        file_id = await get_file_id(file_path)
+        ydoc = await get_jupyter_ydoc(file_id)
+
+        if ydoc:
+            cell_index = _get_cell_index_from_id_ydoc(ydoc, resolved_cell_id)
+            if cell_index is not None:
+                ycell = ydoc.ycells[cell_index]
+                await write_to_cell_collaboratively(ydoc, ycell, content)
+            else:
+                raise ValueError(f"Cell with {cell_id=} not found in notebook")
+        else:
+            with open(file_path, "r", encoding="utf-8") as f:
+                notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT)
+
+            cell_index = _get_cell_index_from_id_nbformat(notebook, resolved_cell_id)
+            if cell_index is not None:
+                notebook.cells[cell_index].source = content
+                with open(file_path, "w", encoding="utf-8") as f:
+                    nbformat.write(notebook, f)
+            else:
+                raise ValueError(f"Cell with {cell_id=} not found in notebook at {file_path=}")
+
+    except Exception:
+        raise
+
+
+# Note: This is currently failing with server outputs, use `read_cell` instead
+def read_cell_nbformat(file_path: str, cell_id: str) -> Dict[str, Any]:
+    """Returns the content and metadata of a cell with the specified ID.
+
+    This function reads a specific cell from a Jupyter notebook file using the nbformat
+    library and returns the cell's content and metadata.
+
+    Note: This function is currently not functioning properly with server outputs.
+    Use `read_cell` instead.
+
+    Args:
+        file_path:
+            The relative path to the notebook file on the filesystem.
+        cell_id:
+            The UUID of the cell to read.
+
+    Returns:
+        The cell as a dictionary containing its content and metadata.
+
+    Raises:
+        ValueError: If no cell with the given ID is found.
+    """
+    file_path = normalize_filepath(file_path)
+    with open(file_path, "r", encoding="utf-8") as f:
+        notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT)
+
+    cell_index = _get_cell_index_from_id_nbformat(notebook, cell_id)
+    if cell_index is not None:
+        cell = notebook.cells[cell_index]
+        return cell
+    else:
+        raise ValueError(f"Cell with {cell_id=} not found in notebook at {file_path=}")
+
+
+def _get_cell_index_from_id_json(notebook_json, cell_id: str) -> int | None:
+    """Get cell index from cell_id using the notebook JSON dict.
+
+    Args:
+        notebook_json:
+            The notebook as a JSON dictionary.
+        cell_id:
+            The UUID of the cell to find.
+
+    Returns:
+        The index of the cell in the notebook, or None if not found.
+    """
+    for i, cell in enumerate(notebook_json["cells"]):
+        if cell.get("id") == cell_id:
+            return i
+    return None
+
+
+def _get_cell_index_from_id_ydoc(ydoc, cell_id: str) -> int | None:
+    """Get cell index from cell_id using YDoc interface.
+
+    Args:
+        ydoc:
+            The YDoc object representing the notebook.
+        cell_id:
+            The UUID of the cell to find.
+
+    Returns:
+        The index of the cell in the notebook, or None if not found.
+    """
+    try:
+        cell_index, _ = ydoc.find_cell(cell_id)
+        return cell_index
+    except (AttributeError, KeyError):
+        return None
+
+
+def _get_cell_index_from_id_nbformat(notebook, cell_id: str) -> int | None:
+    """Get cell index from cell_id using nbformat interface.
+
+    Args:
+        notebook:
+            The nbformat notebook object.
+        cell_id:
+            The UUID of the cell to find.
+
+    Returns:
+        The index of the cell in the notebook, or None if not found.
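+        Both the top-level `id` field and `metadata["id"]` are checked.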
+ """ + for i, cell in enumerate(notebook.cells): + if hasattr(cell, "id") and cell.id == cell_id: + return i + elif hasattr(cell, "metadata") and cell.metadata.get("id") == cell_id: + return i + return None + + +def _determine_insert_index(cells_count: int, cell_index: Optional[int], add_above: bool) -> int: + """Determine the index where a new cell should be inserted. + + Args: + cells_count: + The total number of cells in the notebook. + cell_index: + The index of the reference cell, or None to append at the end. + add_above: + If True, insert above the reference cell; if False, insert below. + + Returns: + The index where the new cell should be inserted. + """ + if cell_index is None: + insert_index = cells_count + else: + if not (0 <= cell_index < cells_count): + cell_index = max(0, min(cell_index, cells_count)) + insert_index = cell_index if add_above else cell_index + 1 + return insert_index + + +async def create_notebook(file_path: str) -> str: + """Creates a new empty Jupyter notebook at the specified file path. + + This function creates a new empty notebook with proper nbformat structure. + If the file already exists, it will return an error message. + + Args: + file_path: + The path where the new notebook should be created. + + Returns: + A success message or error message. + """ + try: + file_path = normalize_filepath(file_path) + + # Check if file already exists + if os.path.exists(file_path): + return f"Error: File already exists at {file_path}" + + # Ensure the directory exists + directory = os.path.dirname(file_path) + if directory and not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + + # Create a new empty notebook + notebook = nbformat.v4.new_notebook() + + # Write the notebook to the file + with open(file_path, "w", encoding="utf-8") as f: + nbformat.write(notebook, f) + + return f"Successfully created new notebook at {file_path}" + + except Exception as e: + return f"Error: Failed to create notebook: {str(e)}" + +toolkit = [ + read_notebook_cells, + add_cell, + insert_cell, + delete_cell, + edit_cell, + create_notebook +] \ No newline at end of file diff --git a/jupyter_ai_jupyternaut/jupyternaut/toolkits/utils.py b/jupyter_ai_jupyternaut/jupyternaut/toolkits/utils.py new file mode 100644 index 0000000..2a834fe --- /dev/null +++ b/jupyter_ai_jupyternaut/jupyternaut/toolkits/utils.py @@ -0,0 +1,374 @@ +import functools +import inspect +import os +from pathlib import Path +from typing import Optional +from urllib.parse import unquote + +from jupyter_server.auth.identity import User +from jupyter_server.serverapp import ServerApp +from pycrdt import Awareness + + +def get_serverapp(): + """Returns the server app from the request context""" + + server = ServerApp.instance() + return server + + +def normalize_filepath(file_path: str) -> str: + """ + Normalizes a file path for Jupyter applications to return an absolute path. 
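+    Relative paths are resolved against the Jupyter server's root directory,
+    falling back to the current working directory if the server app is unavailable.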
+ + Handles various input formats: + - Relative paths from current working directory + - URL-encoded relative paths (common in Jupyter contexts) + - Absolute paths (returned as-is after normalization) + + Args: + file_path: Path in any of the supported formats + + Returns: + Absolute path to the file + + Example: + >>> normalize_filepath("notebooks/my%20notebook.ipynb") + "/current/working/dir/notebooks/my notebook.ipynb" + >>> normalize_filepath("/absolute/path/file.ipynb") + "/absolute/path/file.ipynb" + >>> normalize_filepath("relative/file.ipynb") + "/current/working/dir/relative/file.ipynb" + """ + if not file_path or not file_path.strip(): + raise ValueError("file_path cannot be empty") + + # URL decode the path in case it contains encoded characters + decoded_path = unquote(file_path) + + # Convert to Path object for easier manipulation + path = Path(decoded_path) + + # If already absolute, just normalize and return + if path.is_absolute(): + return str(path.resolve()) + + # For relative paths, get the Jupyter server's root directory + try: + serverapp = get_serverapp() + root_dir = serverapp.root_dir + except Exception: + # Fallback to current working directory if server app is not available + root_dir = os.getcwd() + + # Resolve relative path against the root directory + resolved_path = Path(root_dir) / path + return str(resolved_path.resolve()) + + +async def get_jupyter_ydoc(file_id: str): + """Returns the notebook ydoc + + Args: + file_id: The file ID for the document + + Returns: + `YNotebook` ydoc for the notebook + """ + serverapp = get_serverapp() + yroom_manager = serverapp.web_app.settings["yroom_manager"] + room_id = f"json:notebook:{file_id}" + + if yroom_manager.has_room(room_id): + yroom = yroom_manager.get_room(room_id) + notebook = await yroom.get_jupyter_ydoc() + return notebook + + +async def get_global_awareness() -> Optional[Awareness]: + serverapp = get_serverapp() + yroom_manager = serverapp.web_app.settings["yroom_manager"] + + room_id = "JupyterLab:globalAwareness" + if yroom_manager.has_room(room_id): + yroom = yroom_manager.get_room(room_id) + return yroom.get_awareness() + + # Return None if room doesn't exist + return None + + +async def get_file_id(file_path: str) -> str: + """Returns the file_id for the document + + Args: + file_path: + absolute path to the document file + + Returns: + The file ID of the document + """ + normalized_file_path = normalize_filepath(file_path) + + serverapp = get_serverapp() + file_id_manager = serverapp.web_app.settings["file_id_manager"] + file_id = file_id_manager.get_id(normalized_file_path) + + return file_id + + +def collaborative_tool(user: User): + """ + Decorator factory to enable collaborative awareness for toolkit functions. + + This decorator automatically sets up user awareness in the global + and notebook-specific awareness systems when functions are called. + It enables real-time collaborative features by making the user's + presence visible to other users in the same Jupyter environment. + + Args: + user: Optional user dictionary with user information. If None, no awareness is set. + Should contain keys like 'name', 'color', 'display_name', etc. + + Returns: + Decorator function that wraps the target function with collaborative awareness. + + Example: + >>> user_info = { + ... "name": "Alice", + ... "color": "var(--jp-collaborator-color1)", + ... "display_name": "Alice Smith" + ... } + >>> + >>> @collaborative_tool(user=user_info) + ... async def my_notebook_tool(file_path: str, content: str): + ... 
# Your tool implementation here
+        ...     return f"Processed {file_path}"
+    """
+
+    def decorator(tool_func):
+        @functools.wraps(tool_func)
+        async def wrapper(*args, **kwargs):
+            # Skip awareness if no user provided
+            if user is None:
+                return await tool_func(*args, **kwargs)
+
+            # Get serverapp for logging
+            try:
+                serverapp = get_serverapp()
+                logger = serverapp.log
+            except Exception:
+                logger = None
+
+            # Extract file_path from tool function arguments for notebook-specific awareness
+            file_path = None
+            try:
+                # Try to find file_path in kwargs first
+                if "file_path" in kwargs:
+                    file_path = kwargs["file_path"]
+                else:
+                    # Try to find file_path in positional args by inspecting the function signature
+                    sig = inspect.signature(tool_func)
+                    param_names = list(sig.parameters.keys())
+
+                    # Look for file_path parameter
+                    if "file_path" in param_names and len(args) > param_names.index("file_path"):
+                        file_path = args[param_names.index("file_path")]
+            except Exception as e:
+                # Log error in file_path detection
+                if logger:
+                    logger.warning(f"Error detecting file_path in collaborative_tool: {e}")
+
+            # Set notebook-specific collaborative awareness if we have a file_path
+            if file_path and file_path.endswith(".ipynb"):
+                try:
+                    file_id = await get_file_id(file_path)
+                    ydoc = await get_jupyter_ydoc(file_id)
+
+                    if ydoc:
+                        # Set the local user field in the notebook's awareness
+                        ydoc.awareness.set_local_state_field("user", user)
+                except Exception as e:
+                    # Log error but don't block tool execution
+                    if logger:
+                        logger.warning(
+                            f"Error setting notebook awareness in collaborative_tool: {e}"
+                        )
+
+            # Set global awareness
+            try:
+                g_awareness = await get_global_awareness()
+                if g_awareness:
+                    g_awareness.set_local_state(
+                        {
+                            "user": user,
+                            "current": file_path or "",
+                            "documents": [file_path] if file_path else [],
+                        }
+                    )
+            except Exception as e:
+                # Log error but don't block tool execution
+                if logger:
+                    logger.warning(f"Error setting global awareness in collaborative_tool: {e}")
+
+            # Execute the original tool function
+            return await tool_func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def notebook_json_to_md(notebook_json: dict, include_outputs: bool = True) -> str:
+    """Converts a notebook JSON dict to a markdown string using a custom format.
+
+    Args:
+        notebook_json: The notebook JSON dictionary
+        include_outputs: Whether to include cell outputs in the markdown. Default is True.
+
+    Returns:
+        Markdown string representation of the notebook
+
+    Example:
+        ```markdown
+        ```yaml
+        kernelspec:
+          display_name: Python 3
+          language: python
+          name: python3
+        ```
+
+        ### Cell 0
+
+        #### Metadata
+        ```yaml
+        type: code
+        execution_count: 1
+        ```
+
+        #### Source
+        ```python
+        print("Hello world")
+        ```
+
+        #### Output
+        ```
+        Hello world
+        ```
+        ```
+    """
+    # Extract notebook metadata
+    md_parts = []
+
+    # Add notebook metadata at the top
+    md_parts.append(metadata_to_md(notebook_json.get("metadata", {})))
+
+    # Process all cells
+    for i, cell in enumerate(notebook_json.get("cells", [])):
+        md_parts.append(cell_to_md(cell, index=i, include_outputs=include_outputs))
+
+    # Join all parts with double newlines
+    return "\n\n".join(md_parts)
+
+
+def metadata_to_md(metadata_json: dict) -> str:
+    """Converts notebook or cell metadata to a markdown string in YAML format.
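+    The metadata is serialized with `yaml.dump` in block style and wrapped in a
+    fenced `yaml` code block.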
+
+    Args:
+        metadata_json: The metadata JSON dictionary
+
+    Returns:
+        Markdown string with YAML-formatted metadata
+    """
+    import yaml  # type: ignore[import-untyped]
+
+    yaml_str = yaml.dump(metadata_json, default_flow_style=False)
+    # yaml.dump always ends with a newline, so the closing fence lands on its own line
+    return f"```yaml\n{yaml_str}```"
+
+
+def cell_to_md(cell_json: dict, index: int = 0, include_outputs: bool = True) -> str:
+    """Converts a notebook cell to a markdown string.
+
+    Args:
+        cell_json: The cell JSON dictionary
+        index: Cell index number for the heading
+        include_outputs: Whether to include cell outputs in the markdown
+
+    Returns:
+        Markdown string representation of the cell
+    """
+    md_parts = []
+
+    # Add cell heading with index
+    md_parts.append(f"### Cell {index}")
+
+    # Add metadata section
+    md_parts.append("#### Metadata")
+    metadata = {
+        "type": cell_json.get("cell_type"),
+        "execution_count": cell_json.get("execution_count"),
+    }
+    # Filter out None values
+    metadata = {k: v for k, v in metadata.items() if v is not None}
+    # Add any additional metadata from the cell
+    if "metadata" in cell_json:
+        for key, value in cell_json["metadata"].items():
+            metadata[key] = value
+
+    md_parts.append(metadata_to_md(metadata))
+
+    # Add source section
+    md_parts.append("#### Source")
+    source = "".join(cell_json.get("source", []))
+    # Ensure the source ends with a newline so the closing fence sits on its own line
+    if source and not source.endswith("\n"):
+        source += "\n"
+
+    if cell_json.get("cell_type") == "code":
+        # For code cells, use a python code block
+        md_parts.append(f"```python\n{source}```")
+    else:
+        # For markdown cells, use a plain code block
+        md_parts.append(f"```\n{source}```")
+
+    # Add output section if available and requested
+    if (
+        include_outputs
+        and cell_json.get("cell_type") == "code"
+        and "outputs" in cell_json
+        and cell_json["outputs"]
+    ):
+        md_parts.append("#### Output")
+        md_parts.append(format_outputs(cell_json["outputs"]))
+
+    return "\n\n".join(md_parts)
+
+
+def format_outputs(outputs: list) -> str:
+    """Formats cell outputs into markdown.
+
+    Args:
+        outputs: List of cell output dictionaries
+
+    Returns:
+        Formatted markdown string of the outputs
+    """
+    result = []
+
+    for output in outputs:
+        output_type = output.get("output_type")
+
+        if output_type == "stream":
+            text = "".join(output.get("text", []))
+            # Guarantee a trailing newline so the closing fence is well-formed
+            if text and not text.endswith("\n"):
+                text += "\n"
+            result.append(f"```\n{text}```")
+
+        elif output_type in ("execute_result", "display_data"):
+            data = output.get("data", {})
+
+            # Handle text/plain output
+            if "text/plain" in data:
+                text = "".join(data["text/plain"])
+                if text and not text.endswith("\n"):
+                    text += "\n"
+                result.append(f"```\n{text}```")
+
+            # TODO: Add other mime types
+
+        elif output_type == "error":
+            traceback = "\n".join(output.get("traceback", []))
+            if traceback and not traceback.endswith("\n"):
+                traceback += "\n"
+            result.append(f"```\n{traceback}```")
+
+    return "\n\n".join(result)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index ede857b..5a5efe5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,11 +37,16 @@ dependencies = [
     "deepmerge>=2.0,<3",
     # NOTE: Make sure to update the corresponding dependency in
     # `packages/jupyter-ai/package.json` to match the version range below
-    "jupyterlab-chat>=0.18.0,<0.19.0",
+    "jupyterlab-chat>=0.18.0",
     "jupyter_ai_litellm>=0.0.1",
     "jinja2>=3.0,<4",
     "python_dotenv>=1,<2",
     "jupyter_ai_persona_manager>=0.0.1",
+    "langchain>=1.0.0",
+    "langgraph-checkpoint-sqlite>=3.0.0",
+    "aiosqlite>=0.20",
+    "jupyter_server_documents>=0.1.0a8",
+    "jupyterlab-commands-toolkit>=0.1.2"
 ]
 dynamic = ["version", "description", "authors", "urls", "keywords"]