diff --git a/.gitignore b/.gitignore index 0ab9995..331c962 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,7 @@ wheels/ .venv *.mcpb notebook.ipynb +# Performance log files (branch-name.log) +*.log .vscode + diff --git a/src/android_mcp/__main__.py b/src/android_mcp/__main__.py index bc3f454..58401aa 100644 --- a/src/android_mcp/__main__.py +++ b/src/android_mcp/__main__.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from textwrap import dedent from typing import Literal, Optional -import asyncio import os from fastmcp import FastMCP @@ -185,7 +184,6 @@ def _connect_preferred_device() -> None: @asynccontextmanager async def lifespan(app: FastMCP): """Runs initialization code before the server starts and cleanup code after it shuts down.""" - await asyncio.sleep(1) yield diff --git a/src/android_mcp/mobile/service.py b/src/android_mcp/mobile/service.py index c45da89..9f9a25a 100644 --- a/src/android_mcp/mobile/service.py +++ b/src/android_mcp/mobile/service.py @@ -1,9 +1,11 @@ from android_mcp.mobile.views import MobileState from android_mcp.tree.service import Tree +from android_mcp.perf_log import timed, log_separator import uiautomator2 as u2 from io import BytesIO from PIL import Image import subprocess +import threading import base64 import os from typing import Optional @@ -90,19 +92,19 @@ def get_device(self): return self.device def capture_data(self, use_vision: bool = True): - import threading data = {} def get_xml(): try: - data['xml'] = self.device.dump_hierarchy() + with timed("capture_data.dump_hierarchy"): + data['xml'] = self.device.dump_hierarchy() except Exception as e: data['xml_error'] = e def get_img(): try: - # Use format="pillow" to ensure we get a PIL image immediately - data['img'] = self.device.screenshot(format="pillow") + with timed("capture_data.screenshot"): + data['img'] = self.device.screenshot(format="pillow") except Exception as e: data['img_error'] = e @@ -112,8 +114,9 @@ def get_img(): for t in threads: t.start() - for t in threads: - t.join() + with timed("capture_data.total_parallel"): + for t in threads: + t.join() if 'xml_error' in data: raise data['xml_error'] @@ -124,18 +127,29 @@ def get_img(): def get_state(self, use_vision=False, as_bytes: bool = False, as_base64: bool = False, use_annotation: bool = True): try: - xml_data, screenshot_data = self.capture_data(use_vision=use_vision) - tree = Tree(self) - tree_state = tree.get_state(xml_data=xml_data) + log_separator(f"get_state use_vision={use_vision}") + with timed("get_state.capture_data"): + xml_data, screenshot_data = self.capture_data(use_vision=use_vision) + with timed("get_state.tree_state"): + tree = Tree(self) + tree_state = tree.get_state(xml_data=xml_data) if use_vision: nodes = tree_state.interactive_elements + scale = float(os.getenv("SCREENSHOT_SCALE", "0.5")) + w, h = screenshot_data.size + with timed("get_state.screenshot_resize"): + screenshot_data = screenshot_data.resize( + (int(w * scale), int(h * scale)), Image.Resampling.LANCZOS + ) if use_annotation: - screenshot = tree.annotated_screenshot(nodes=nodes, scale=1.0, screenshot=screenshot_data) + with timed("get_state.annotated_screenshot"): + screenshot = tree.annotated_screenshot(nodes=nodes, scale=scale, screenshot=screenshot_data) else: screenshot = screenshot_data if os.getenv("SCREENSHOT_QUANTIZED") in ["1", "yes", "true", True]: - screenshot = self.quantized_screenshot(screenshot) + with timed("get_state.quantize"): + screenshot = self.quantized_screenshot(screenshot) if as_base64: screenshot = self.as_base64(screenshot) @@ -172,7 +186,8 @@ def screenshot_in_bytes(self,screenshot:Image.Image)->bytes: if screenshot is None: raise ValueError("Screenshot is None") io=BytesIO() - screenshot.save(io,format='PNG') + with timed("screenshot_in_bytes.png_save"): + screenshot.save(io,format='PNG',compress_level=1) bytes=io.getvalue() if len(bytes) == 0: raise ValueError("Screenshot conversion resulted in empty bytes.") @@ -185,7 +200,8 @@ def as_base64(self,screenshot:Image.Image)->str: if screenshot is None: raise ValueError("Screenshot is None") io=BytesIO() - screenshot.save(io,format='PNG') + with timed("as_base64.png_save"): + screenshot.save(io,format='PNG',compress_level=1) bytes=io.getvalue() if len(bytes) == 0: raise ValueError("Screenshot conversion resulted in empty bytes.") diff --git a/src/android_mcp/perf_log.py b/src/android_mcp/perf_log.py new file mode 100644 index 0000000..8b87c27 --- /dev/null +++ b/src/android_mcp/perf_log.py @@ -0,0 +1,48 @@ +""" +Performance instrumentation utility. +Timing data is appended to {branch-name}.log in the working directory. +""" +import subprocess +import time +from contextlib import contextmanager + + +def _get_branch_name() -> str: + try: + result = subprocess.run( + ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], + capture_output=True, text=True, timeout=5 + ) + branch = result.stdout.strip() + if branch and branch != 'HEAD': + return branch.replace('/', '-') + except Exception: + pass + return 'unknown' + + +_LOG_FILE = f"{_get_branch_name()}.log" + + +def _write(line: str) -> None: + try: + with open(_LOG_FILE, 'a') as f: + f.write(line + '\n') + except Exception: + pass + + +def log_separator(label: str = "") -> None: + """Write a section separator, e.g. at the start of a Snapshot call.""" + import datetime + ts = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] + _write(f"\n--- {label} [{ts}] ---") + + +@contextmanager +def timed(label: str): + """Context manager that times a block and appends the result to the log file.""" + t0 = time.perf_counter() + yield + elapsed_ms = (time.perf_counter() - t0) * 1000 + _write(f" {label}: {elapsed_ms:.1f}ms") diff --git a/src/android_mcp/tree/service.py b/src/android_mcp/tree/service.py index 6f0909d..4754d3e 100644 --- a/src/android_mcp/tree/service.py +++ b/src/android_mcp/tree/service.py @@ -1,6 +1,7 @@ from android_mcp.tree.views import TreeState, ElementNode, CenterCord, BoundingBox from android_mcp.tree.utils import extract_cordinates,get_center_cordinates from android_mcp.tree.config import INTERACTIVE_CLASSES +from android_mcp.perf_log import timed from PIL import Image, ImageFont, ImageDraw from xml.etree.ElementTree import Element from xml.etree import ElementTree @@ -20,7 +21,8 @@ def __init__(self,mobile:'Mobile'): def get_element_tree(self, xml_data=None)->'Element': tree_string = xml_data if xml_data else self.mobile.device.dump_hierarchy() - return ElementTree.fromstring(tree_string) + with timed("get_element_tree.fromstring"): + return ElementTree.fromstring(tree_string) def get_state(self, xml_data=None)->TreeState: interactive_elements=self.get_interactive_elements(xml_data=xml_data) @@ -29,23 +31,25 @@ def get_state(self, xml_data=None)->TreeState: def get_interactive_elements(self, xml_data=None)->list: interactive_elements=[] element_tree = self.get_element_tree(xml_data=xml_data) - nodes=element_tree.findall('.//node[@enabled="true"]') - for node in nodes: - if self.is_interactive(node): - x1,y1,x2,y2 = extract_cordinates(node) - name=self.get_element_name(node) - if not name: - continue - x_center,y_center = get_center_cordinates((x1,y1,x2,y2)) - raw_id=node.get('resource-id','') - short_id=raw_id.split('/')[-1] if '/' in raw_id else raw_id - interactive_elements.append(ElementNode(**{ - 'name':name, - 'class_name':node.get('class'), - 'coordinates':CenterCord(x=x_center,y=y_center), - 'bounding_box':BoundingBox(x1=x1,y1=y1,x2=x2,y2=y2), - 'resource_id':short_id - })) + with timed("get_interactive_elements.findall"): + nodes=element_tree.findall('.//node[@enabled="true"]') + with timed("get_interactive_elements.filter_loop"): + for node in nodes: + if self.is_interactive(node): + x1,y1,x2,y2 = extract_cordinates(node) + name=self.get_element_name(node) + if not name: + continue + x_center,y_center = get_center_cordinates((x1,y1,x2,y2)) + raw_id=node.get('resource-id','') + short_id=raw_id.split('/')[-1] if '/' in raw_id else raw_id + interactive_elements.append(ElementNode(**{ + 'name':name, + 'class_name':node.get('class'), + 'coordinates':CenterCord(x=x_center,y=y_center), + 'bounding_box':BoundingBox(x1=x1,y1=y1,x2=x2,y2=y2), + 'resource_id':short_id + })) return interactive_elements def get_element_name(self, node) -> str: @@ -95,7 +99,8 @@ def is_interactive(self, node) -> bool: def annotated_screenshot(self, nodes: list[ElementNode],scale:float=0.7, screenshot=None) -> Image.Image: if screenshot is None: - screenshot = self.mobile.get_screenshot(scale=scale) + with timed("annotated_screenshot.get_screenshot"): + screenshot = self.mobile.get_screenshot(scale=scale) draw = ImageDraw.Draw(screenshot) font_size = 12 @@ -135,7 +140,8 @@ def draw_annotation(label, node: ElementNode): draw.rectangle([(label_x1, label_y1), (label_x2, label_y2)], fill=color) draw.text((label_x1 + 2, label_y1 + 2), str(label), fill=(255, 255, 255), font=font) - for i, node in enumerate(nodes): - draw_annotation(i, node) + with timed("annotated_screenshot.draw_all"): + for i, node in enumerate(nodes): + draw_annotation(i, node) return screenshot diff --git a/src/android_mcp/tree/utils.py b/src/android_mcp/tree/utils.py index 3dc630e..939e8ee 100644 --- a/src/android_mcp/tree/utils.py +++ b/src/android_mcp/tree/utils.py @@ -1,9 +1,11 @@ import re +_BOUNDS_RE = re.compile(r'\[(\d+),(\d+)]\[(\d+),(\d+)]') + def extract_cordinates(node): attributes = node.attrib bounds=attributes.get('bounds') - match = re.search(r'\[(\d+),(\d+)]\[(\d+),(\d+)]', bounds) + match = _BOUNDS_RE.search(bounds) if match: x1, y1, x2, y2 = map(int, match.groups()) return x1, y1, x2, y2