Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions memgraph-toolbox/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Below is a list of tools included in the toolbox, along with their descriptions:
7. `CypherTool` - Executes arbitrary [Cypher queries](https://memgraph.com/docs/querying) on a Memgraph database.
8. `ShowConstraintInfoTool` - Shows [constraint](https://memgraph.com/docs/fundamentals/constraints) information from a Memgraph database.
9. `ShowConfigTool` - Shows [configuration](https://memgraph.com/docs/database-management/configuration) information from a Memgraph database.
10. `NodeVectorSearchTool` - Searches the most similar nodes using the Memgraph's [vector search](https://memgraph.com/docs/querying/vector-search).
11. `NodeNeighborhoodTool` - Searches for the data attached to a given node using Memgraph's [deep-path traversals](https://memgraph.com/docs/advanced-algorithms/deep-path-traversal).

## Usage

Expand Down
2 changes: 2 additions & 0 deletions memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .tools.constraint import ShowConstraintInfoTool
from .tools.cypher import CypherTool
from .tools.index import ShowIndexInfoTool
from .tools.node_neighborhood import NodeNeighborhoodTool
from .tools.node_vector_search import NodeVectorSearchTool
from .tools.page_rank import PageRankTool
from .tools.schema import ShowSchemaInfoTool
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(self, db: Memgraph):
self.add_tool(ShowConstraintInfoTool(db))
self.add_tool(CypherTool(db))
self.add_tool(ShowIndexInfoTool(db))
self.add_tool(NodeNeighborhoodTool(db))
self.add_tool(NodeVectorSearchTool(db))
self.add_tool(PageRankTool(db))
self.add_tool(ShowSchemaInfoTool(db))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_memgraph_toolbox():
tools = toolkit.get_all_tools()

# Check if we have all 9 tools
assert len(tools) == 10
assert len(tools) == 11

# Check for specific tool names
tool_names = [tool.name for tool in tools]
Expand All @@ -66,6 +66,7 @@ def test_memgraph_toolbox():
"show_schema_info",
"show_storage_info",
"show_triggers",
"node_neighborhood",
]

for expected_tool in expected_tools:
Expand Down
29 changes: 29 additions & 0 deletions memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..tools.constraint import ShowConstraintInfoTool
from ..tools.cypher import CypherTool
from ..tools.index import ShowIndexInfoTool
from ..tools.node_neighborhood import NodeNeighborhoodTool
from ..tools.node_vector_search import NodeVectorSearchTool
from ..tools.page_rank import PageRankTool
from ..tools.schema import ShowSchemaInfoTool
Expand Down Expand Up @@ -282,3 +283,31 @@ def test_node_vector_search_tool():
'MATCH (n:Person) WHERE "embedding" IN keys(n) DETACH DELETE n'
)
memgraph_client.query("DROP VECTOR INDEX my_index")


def test_node_neighborhood_tool():
"""Test the NodeNeighborhood tool."""
url = "bolt://localhost:7687"
user = ""
password = ""
memgraph_client = Memgraph(url=url, username=user, password=password)

label = "TestNodeNeighborhoodToolLabel"
memgraph_client.query(f"MATCH (n:{label}) DETACH DELETE n;")
memgraph_client.query(
f"CREATE (p1:{label} {{id: 1}})-[:KNOWS]->(p2:{label} {{id: 2}}), (p2)-[:KNOWS]->(p3:{label} {{id: 3}});"
)
memgraph_client.query(
f"CREATE (p4:{label} {{id: 4}})-[:KNOWS]->(p5:{label} {{id: 5}});"
)
ids = memgraph_client.query(
f"MATCH (p1:{label} {{id:1}}) RETURN id(p1) AS node_id;"
)
assert len(ids) == 1
node_id = ids[0]["node_id"]

node_neighborhood_tool = NodeNeighborhoodTool(db=memgraph_client)
result = node_neighborhood_tool.call({"node_id": node_id, "max_distance": 2})
assert isinstance(result, list)
assert len(result) == 2
memgraph_client.query(f"MATCH (n:{label}) DETACH DELETE n;")
1 change: 1 addition & 0 deletions memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Any, Dict, List

from ..api.memgraph import Memgraph
from ..api.tool import BaseTool


class NodeNeighborhoodTool(BaseTool):
"""
Tool for finding nodes within a specified neighborhood distance in Memgraph.
"""

def __init__(self, db: Memgraph):
super().__init__(
name="node_neighborhood",
description=(
"Finds nodes within a specified distance from a given node. "
"This tool explores the graph neighborhood around a starting node, "
"returning all nodes and relationships found within the specified radius."
),
input_schema={
"type": "object",
"properties": {
"node_id": {
"type": "string",
"description": "The ID of the starting node to find neighborhood around",
},
"max_distance": {
"type": "integer",
"description": "Maximum distance (hops) to search from the starting node. Default is 2.",
"default": 1,
},
"limit": {
"type": "integer",
"description": "Maximum number of nodes to return. Default is 100.",
"default": 100,
},
},
"required": ["node_id"],
},
)
self.db = db

def call(self, arguments: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Execute the neighborhood search and return the results."""
node_id = arguments["node_id"]
max_distance = arguments.get("max_distance", 1)
limit = arguments.get("limit", 100)

query = f"""MATCH (n)-[r*..{max_distance}]-(m) WHERE id(n) = {node_id} RETURN DISTINCT m LIMIT {limit};"""
try:
results = self.db.query(query, {})
processed_results = []
for record in results:
node_data = record["m"];
properties = {k: v for k, v in node_data.items()}
processed_results.append(properties)
return processed_results
except Exception as e:
return [{"error": f"Failed to find neighborhood: {str(e)}"}]