Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ wheels/
.venv
*.mcpb
notebook.ipynb
# Performance log files (branch-name.log)
*.log
.vscode
45 changes: 30 additions & 15 deletions src/android_mcp/mobile/service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from android_mcp.mobile.views import MobileState
from android_mcp.tree.service import Tree
from android_mcp.perf_log import timed, log_separator
import uiautomator2 as u2
from io import BytesIO
from PIL import Image
import subprocess
import threading
import base64
import os
from typing import Optional
Expand Down Expand Up @@ -90,19 +92,19 @@ def get_device(self):
return self.device

def capture_data(self, use_vision: bool = True):
import threading
data = {}

def get_xml():
try:
data['xml'] = self.device.dump_hierarchy()
with timed("capture_data.dump_hierarchy"):
data['xml'] = self.device.dump_hierarchy()
except Exception as e:
data['xml_error'] = e

def get_img():
try:
# Use format="pillow" to ensure we get a PIL image immediately
data['img'] = self.device.screenshot(format="pillow")
with timed("capture_data.screenshot"):
data['img'] = self.device.screenshot(format="pillow")
except Exception as e:
data['img_error'] = e

Expand All @@ -112,8 +114,9 @@ def get_img():

for t in threads:
t.start()
for t in threads:
t.join()
with timed("capture_data.total_parallel"):
for t in threads:
t.join()

if 'xml_error' in data:
raise data['xml_error']
Expand All @@ -124,18 +127,29 @@ def get_img():

def get_state(self, use_vision=False, as_bytes: bool = False, as_base64: bool = False, use_annotation: bool = True):
try:
xml_data, screenshot_data = self.capture_data(use_vision=use_vision)
tree = Tree(self)
tree_state = tree.get_state(xml_data=xml_data)
log_separator(f"get_state use_vision={use_vision}")
with timed("get_state.capture_data"):
xml_data, screenshot_data = self.capture_data(use_vision=use_vision)
with timed("get_state.tree_state"):
tree = Tree(self)
tree_state = tree.get_state(xml_data=xml_data)

if use_vision:
nodes = tree_state.interactive_elements
scale = float(os.getenv("SCREENSHOT_SCALE", "0.5"))
w, h = screenshot_data.size
with timed("get_state.screenshot_resize"):
screenshot_data = screenshot_data.resize(
(int(w * scale), int(h * scale)), Image.Resampling.LANCZOS
)
if use_annotation:
screenshot = tree.annotated_screenshot(nodes=nodes, scale=1.0, screenshot=screenshot_data)
with timed("get_state.annotated_screenshot"):
screenshot = tree.annotated_screenshot(nodes=nodes, scale=scale, screenshot=screenshot_data)
else:
screenshot = screenshot_data
if os.getenv("SCREENSHOT_QUANTIZED") in ["1", "yes", "true", True]:
screenshot = self.quantized_screenshot(screenshot)
with timed("get_state.quantize"):
screenshot = self.quantized_screenshot(screenshot)

if as_base64:
screenshot = self.as_base64(screenshot)
Expand All @@ -146,7 +160,7 @@ def get_state(self, use_vision=False, as_bytes: bool = False, as_base64: bool =
return MobileState(tree_state=tree_state, screenshot=screenshot)
except Exception as e:
raise RuntimeError(f"Failed to get device state: {e}")

def get_screenshot(self,scale:float=0.7)->Image.Image:
try:
screenshot=self.device.screenshot()
Expand All @@ -172,7 +186,8 @@ def screenshot_in_bytes(self,screenshot:Image.Image)->bytes:
if screenshot is None:
raise ValueError("Screenshot is None")
io=BytesIO()
screenshot.save(io,format='PNG')
with timed("screenshot_in_bytes.png_save"):
screenshot.save(io,format='PNG',compress_level=1)
bytes=io.getvalue()
if len(bytes) == 0:
raise ValueError("Screenshot conversion resulted in empty bytes.")
Expand All @@ -185,12 +200,12 @@ def as_base64(self,screenshot:Image.Image)->str:
if screenshot is None:
raise ValueError("Screenshot is None")
io=BytesIO()
screenshot.save(io,format='PNG')
with timed("as_base64.png_save"):
screenshot.save(io,format='PNG',compress_level=1)
bytes=io.getvalue()
if len(bytes) == 0:
raise ValueError("Screenshot conversion resulted in empty bytes.")
return base64.b64encode(bytes).decode('utf-8')
except Exception as e:
raise RuntimeError(f"Failed to convert screenshot to base64: {e}")


48 changes: 48 additions & 0 deletions src/android_mcp/perf_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Performance instrumentation utility.
Timing data is appended to {branch-name}.log in the working directory.
"""
import subprocess
import time
from contextlib import contextmanager


def _get_branch_name() -> str:
    """Return the current git branch name in a filename-friendly form.

    Runs ``git rev-parse --abbrev-ref HEAD`` with a 5 second timeout and
    replaces every '/' in the reported name with '-'.

    Returns:
        The sanitized branch name, or 'unknown' when the command cannot be
        run, produces no output, or reports the literal name 'HEAD'.
    """
    fallback = 'unknown'
    git_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
    try:
        completed = subprocess.run(
            git_cmd, capture_output=True, text=True, timeout=5
        )
        name = completed.stdout.strip()
    except Exception:
        # Best-effort lookup: any failure falls through to the placeholder.
        return fallback
    if not name or name == 'HEAD':
        return fallback
    return name.replace('/', '-')


# Log file name, computed once at import time from the current git branch.
# Evaluating this calls _get_branch_name(), so importing the module runs a
# git subprocess. The name is relative, so open() resolves it against the
# process's working directory.
_LOG_FILE = f"{_get_branch_name()}.log"


def _write(line: str) -> None:
    """Append ``line`` plus a trailing newline to the performance log.

    Logging is strictly best-effort: any failure (for example an unwritable
    directory or a bad argument) is swallowed so that instrumentation can
    never break the code being measured.

    Args:
        line: Text to append to ``_LOG_FILE``.
    """
    try:
        # Explicit UTF-8: with the locale default encoding, a label holding
        # characters that encoding cannot represent would raise here and the
        # line would be silently lost.
        with open(_LOG_FILE, 'a', encoding='utf-8') as f:
            f.write(line + '\n')
    except Exception:
        # Deliberately silent; see the docstring.
        pass


def log_separator(label: str = "") -> None:
    """Write a section separator, e.g. at the start of a Snapshot call.

    The text handed to ``_write`` starts with a newline and has the form
    ``--- <label> [HH:MM:SS.mmm] ---``, stamped with the local wall-clock time.

    Args:
        label: Free-form text placed inside the separator.
    """
    from datetime import datetime
    # Time of day truncated (not rounded) to milliseconds.
    ts = datetime.now().time().isoformat(timespec="milliseconds")
    header = f"\n--- {label} [{ts}] ---"
    _write(header)


@contextmanager
def timed(label: str):
    """Context manager that times a block and appends the result to the log file.

    The elapsed time is measured with ``time.perf_counter`` and handed to
    ``_write`` as ``" <label>: <milliseconds>ms"`` with one decimal place.
    The line is written even when the wrapped block raises; the exception
    then propagates to the caller unchanged.

    Args:
        label: Name identifying the timed section in the log.
    """
    t0 = time.perf_counter()
    try:
        yield
    finally:
        # Without the finally clause an exception thrown into the generator
        # at ``yield`` would skip the logging below, hiding the duration of
        # exactly the calls that failed.
        elapsed_ms = (time.perf_counter() - t0) * 1000
        _write(f" {label}: {elapsed_ms:.1f}ms")
68 changes: 37 additions & 31 deletions src/android_mcp/tree/service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from android_mcp.tree.views import TreeState, ElementNode, CenterCord, BoundingBox
from android_mcp.tree.utils import extract_cordinates,get_center_cordinates
from android_mcp.tree.config import INTERACTIVE_CLASSES
from android_mcp.perf_log import timed
from PIL import Image, ImageFont, ImageDraw
from xml.etree.ElementTree import Element
from xml.etree import ElementTree
Expand All @@ -20,71 +21,74 @@ def __init__(self,mobile:'Mobile'):

def get_element_tree(self, xml_data=None)->'Element':
    """Parse the UI hierarchy XML and return its root element.

    Args:
        xml_data: Hierarchy XML to parse. When falsy, a fresh dump is taken
            via ``self.mobile.device.dump_hierarchy()``.

    Returns:
        The root ``Element`` produced by ``ElementTree.fromstring``.
    """
    tree_string = xml_data if xml_data else self.mobile.device.dump_hierarchy()
    # Only the XML parsing step is timed. The earlier unconditional return
    # made this instrumented path unreachable.
    with timed("get_element_tree.fromstring"):
        return ElementTree.fromstring(tree_string)

def get_state(self, xml_data=None)->TreeState:
    """Build a TreeState from the interactive elements of the hierarchy.

    Args:
        xml_data: Optional hierarchy XML, forwarded unchanged to
            ``get_interactive_elements``.
    """
    return TreeState(
        interactive_elements=self.get_interactive_elements(xml_data=xml_data)
    )

def get_interactive_elements(self, xml_data=None)->list:
    """Collect an ElementNode for every enabled, interactive, named node.

    Nodes are taken from the parsed hierarchy with the XPath
    ``.//node[@enabled="true"]``; a node is kept when ``is_interactive``
    accepts it and ``get_element_name`` yields a non-empty name.

    Args:
        xml_data: Optional hierarchy XML, forwarded to ``get_element_tree``.

    Returns:
        A list of ``ElementNode`` objects, one per kept node, in document
        order of the ``findall`` result.
    """
    interactive_elements=[]
    element_tree = self.get_element_tree(xml_data=xml_data)
    with timed("get_interactive_elements.findall"):
        nodes=element_tree.findall('.//node[@enabled="true"]')
    # A single pass only: running the search and the loop twice would append
    # every element twice.
    with timed("get_interactive_elements.filter_loop"):
        for node in nodes:
            if not self.is_interactive(node):
                continue
            x1,y1,x2,y2 = extract_cordinates(node)
            name=self.get_element_name(node)
            if not name:
                continue
            x_center,y_center = get_center_cordinates((x1,y1,x2,y2))
            raw_id=node.get('resource-id','')
            # Keep only the part after the last '/' of the resource id.
            short_id=raw_id.split('/')[-1] if '/' in raw_id else raw_id
            interactive_elements.append(ElementNode(
                name=name,
                class_name=node.get('class'),
                coordinates=CenterCord(x=x_center,y=y_center),
                bounding_box=BoundingBox(x1=x1,y1=y1,x2=x2,y2=y2),
                resource_id=short_id,
            ))
    return interactive_elements

def get_element_name(self, node) -> str:
    """Derive a display name for ``node``.

    The node's own ``content-desc`` (preferred) or ``text`` is used when
    present. Otherwise text is gathered from the subtree: the ``text``,
    ``content-desc`` or ``hint`` of the node and of its non-actionable
    descendants. Descendants that are clickable, long-clickable, checkable
    or scrollable are not descended into; their own text is used only as a
    fallback when nothing else was found.

    Args:
        node: XML element exposing ``get`` and child iteration.

    Returns:
        The name, or an empty string when no text could be found.
    """
    name = node.get('content-desc') or node.get('text')
    if name:
        return name

    actionable_attrs = ('clickable', 'long-clickable', 'checkable', 'scrollable')
    texts = []
    fallback_texts = []

    def collect_text(n):
        # The node we started from is never treated as actionable, so its
        # own hint and its children are always inspected.
        is_actionable = (n is not node) and any(
            n.get(attr) == "true" for attr in actionable_attrs
        )
        val = n.get('text') or n.get('content-desc') or n.get('hint')

        if is_actionable:
            if val:
                fallback_texts.append(val)
            return  # Stop recursing into actionable nodes

        if val:
            texts.append(val)
        for child in n:
            collect_text(child)

    collect_text(node)

    # Use primary texts if found, otherwise fallback texts from actionable children.
    final_texts = texts if texts else fallback_texts
    return " ".join(final_texts).strip()

def is_interactive(self, node) -> bool:
attributes = node.attrib
return (attributes.get('focusable') == "true" or
return (attributes.get('focusable') == "true" or
attributes.get('clickable') == "true" or
attributes.get('long-clickable') == "true" or
attributes.get('checkable') == "true" or
Expand All @@ -95,7 +99,8 @@ def is_interactive(self, node) -> bool:

def annotated_screenshot(self, nodes: list[ElementNode],scale:float=0.7, screenshot=None) -> Image.Image:
if screenshot is None:
screenshot = self.mobile.get_screenshot(scale=scale)
with timed("annotated_screenshot.get_screenshot"):
screenshot = self.mobile.get_screenshot(scale=scale)

draw = ImageDraw.Draw(screenshot)
font_size = 12
Expand Down Expand Up @@ -135,7 +140,8 @@ def draw_annotation(label, node: ElementNode):
draw.rectangle([(label_x1, label_y1), (label_x2, label_y2)], fill=color)
draw.text((label_x1 + 2, label_y1 + 2), str(label), fill=(255, 255, 255), font=font)

for i, node in enumerate(nodes):
draw_annotation(i, node)
with timed("annotated_screenshot.draw_all"):
for i, node in enumerate(nodes):
draw_annotation(i, node)

return screenshot