diff --git a/src/main/java/net/sourceforge/plantuml/servlet/mcp/McpServlet.java b/src/main/java/net/sourceforge/plantuml/servlet/mcp/McpServlet.java index 118b48ba..86de85fe 100644 --- a/src/main/java/net/sourceforge/plantuml/servlet/mcp/McpServlet.java +++ b/src/main/java/net/sourceforge/plantuml/servlet/mcp/McpServlet.java @@ -119,14 +119,20 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) try { JsonObject requestBody = readJsonRequest(request); - if (pathInfo.equals("/render")) { + if (pathInfo.equals("/check")) { + handleCheck(requestBody, response); + } else if (pathInfo.equals("/render")) { handleRender(requestBody, response); + } else if (pathInfo.equals("/metadata")) { + handleMetadata(requestBody, response); } else if (pathInfo.equals("/render-url")) { handleRenderUrl(requestBody, response); } else if (pathInfo.equals("/analyze")) { handleAnalyze(requestBody, response); } else if (pathInfo.equals("/workspace/create")) { handleWorkspaceCreate(requestBody, response); + } else if (pathInfo.equals("/workspace/put")) { + handleWorkspaceUpdate(requestBody, response); } else if (pathInfo.equals("/workspace/update")) { handleWorkspaceUpdate(requestBody, response); } else if (pathInfo.equals("/workspace/get")) { @@ -138,6 +144,10 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } else { sendError(response, HttpServletResponse.SC_NOT_FOUND, "Endpoint not found"); } + } catch (com.google.gson.JsonSyntaxException e) { + // Handle JSON parsing errors + sendError(response, HttpServletResponse.SC_BAD_REQUEST, + "Invalid JSON: " + e.getMessage()); } catch (Exception e) { // Log error (servlet container will handle logging) sendError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, @@ -240,6 +250,106 @@ private void handleExamplesGet(HttpServletRequest request, HttpServletResponse r sendJson(response, result); } + private void handleCheck(JsonObject requestBody, HttpServletResponse response) + throws IOException { + String source = getJsonString(requestBody, "source", null); + if (source == null || source.isEmpty()) { + sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Missing 'source' field"); + return; + } + + try { + // Try to parse the diagram to check for syntax errors + SourceStringReader reader = new SourceStringReader(source); + // Use NullOutputStream to avoid generating actual image data + reader.outputImage(new net.sourceforge.plantuml.servlet.utility.NullOutputStream(), + 0, new FileFormatOption(FileFormat.PNG)); + + Map result = new HashMap<>(); + result.put("ok", true); + result.put("errors", new Object[0]); + sendJson(response, result); + } catch (Exception e) { + Map result = new HashMap<>(); + result.put("ok", false); + Map error = new HashMap<>(); + error.put("line", 0); + error.put("message", e.getMessage() != null ? e.getMessage() : "Syntax error"); + result.put("errors", new Object[]{error}); + sendJson(response, result); + } + } + + private void handleMetadata(JsonObject requestBody, HttpServletResponse response) + throws IOException { + String source = getJsonString(requestBody, "source", null); + if (source == null || source.isEmpty()) { + sendError(response, HttpServletResponse.SC_BAD_REQUEST, "Missing 'source' field"); + return; + } + + try { + // Extract basic metadata from the source + Map result = new HashMap<>(); + + // Parse participants/entities from the source + java.util.List participants = new java.util.ArrayList<>(); + String[] lines = source.split("\n"); + for (String line : lines) { + // Simple parsing for common diagram elements + line = line.trim(); + if (line.matches("^[a-zA-Z0-9_]+\\s*->.*") || line.matches(".*->\\s*[a-zA-Z0-9_]+.*")) { + // Extract participant names from arrow notations + String[] parts = line.split("->"); + for (String part : parts) { + String trimmed = part.trim(); + if (!trimmed.isEmpty()) { + String[] tokens = trimmed.split("\\s+"); + if (tokens.length > 0) { + String name = tokens[0].replaceAll("[^a-zA-Z0-9_]", ""); + if (!name.isEmpty() && !participants.contains(name)) { + participants.add(name); + } + } + } + } + } else if (line.matches("^(class|interface|entity|participant)\\s+[a-zA-Z0-9_]+.*")) { + String[] parts = line.split("\\s+"); + if (parts.length >= 2) { + String name = parts[1].replaceAll("[^a-zA-Z0-9_]", ""); + if (!name.isEmpty() && !participants.contains(name)) { + participants.add(name); + } + } + } + } + + result.put("participants", participants.toArray(new String[0])); + result.put("directives", new String[0]); + + // Detect diagram type + String diagramType = "unknown"; + if (source.contains("@startuml")) { + if (source.contains("->") || source.contains("participant")) { + diagramType = "sequence"; + } else if (source.contains("class") || source.contains("interface")) { + diagramType = "class"; + } else if (source.contains("state")) { + diagramType = "state"; + } else if (source.contains("usecase") || source.contains("actor")) { + diagramType = "usecase"; + } + } + result.put("diagramType", diagramType); + result.put("warnings", new String[0]); + + sendJson(response, result); + } catch (Exception e) { + sendError(response, HttpServletResponse.SC_BAD_REQUEST, + "Metadata extraction failed: " + e.getMessage()); + } + } + private void handleRender(JsonObject requestBody, HttpServletResponse response) throws IOException { String source = getJsonString(requestBody, "source", null); @@ -258,20 +368,22 @@ private void handleRender(JsonObject requestBody, HttpServletResponse response) reader.outputImage(outputStream, 0, new FileFormatOption(fileFormat)); byte[] imageBytes = outputStream.toByteArray(); - String dataUrl = formatDataUrl(imageBytes, fileFormat); - String sha256 = computeSha256(imageBytes); + String dataBase64 = Base64.getEncoder().encodeToString(imageBytes); Map result = new HashMap<>(); - result.put("status", "ok"); + result.put("ok", true); result.put("format", format); - result.put("dataUrl", dataUrl); - result.put("renderTimeMs", System.currentTimeMillis() - startTime); - result.put("sha256", sha256); + result.put("dataBase64", dataBase64); sendJson(response, result); } catch (Exception e) { - sendError(response, HttpServletResponse.SC_BAD_REQUEST, - "Rendering failed: " + e.getMessage()); + Map errorResult = new HashMap<>(); + errorResult.put("ok", false); + errorResult.put("errors", new Object[]{ + java.util.Collections.singletonMap("message", "Rendering failed: " + e.getMessage()) + }); + response.setStatus(HttpServletResponse.SC_OK); + sendJson(response, errorResult); } } diff --git a/src/test/java/net/sourceforge/plantuml/servlet/mcp/McpServletTest.java b/src/test/java/net/sourceforge/plantuml/servlet/mcp/McpServletTest.java new file mode 100644 index 00000000..f02a8f85 --- /dev/null +++ b/src/test/java/net/sourceforge/plantuml/servlet/mcp/McpServletTest.java @@ -0,0 +1,196 @@ +package net.sourceforge.plantuml.servlet.mcp; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.charset.StandardCharsets; + +import org.junit.jupiter.api.Test; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; + +import net.sourceforge.plantuml.servlet.utils.WebappTestCase; + +/** + * Unit tests for McpServlet as specified in the issue requirements. + * These tests use the WebappTestCase framework instead of direct servlet mocking + * to avoid dependency conflicts. + */ +public class McpServletTest extends WebappTestCase { + + private static final Gson GSON = new Gson(); + + /** + * Helper method to make a POST request with JSON body. + */ + private HttpURLConnection postJson(String path, String json) throws IOException { + URL url = new URL(getServerUrl() + path); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setDoOutput(true); + + try (OutputStream os = conn.getOutputStream()) { + byte[] input = json.getBytes(StandardCharsets.UTF_8); + os.write(input, 0, input.length); + } + + return conn; + } + + /** + * Helper method to extract workspaceId from JSON response. + */ + private String extractWorkspaceId(String json) { + JsonObject obj = GSON.fromJson(json, JsonObject.class); + if (obj.has("workspaceId")) { + return obj.get("workspaceId").getAsString(); + } + return null; + } + + /** + * Test: check endpoint accepts valid diagram. + */ + @Test + void checkEndpointShouldReturnOkForValidDiagram() throws Exception { + String json = "{ \"source\": \"@startuml\\nAlice -> Bob\\n@enduml\" }"; + + HttpURLConnection conn = postJson("/mcp/check", json); + int responseCode = conn.getResponseCode(); + + if (responseCode == 404) { + // MCP not enabled, skip this test + return; + } + + assertEquals(200, responseCode); + String body = getContentText(conn); + + assertTrue(body.contains("\"ok\":true")); + assertTrue(body.contains("\"errors\":[]")); + } + + /** + * Test: check endpoint should report syntax errors. + */ + @Test + void checkEndpointShouldReportErrors() throws Exception { + String json = "{ \"source\": \"@startuml\\nThis is wrong\\n@enduml\" }"; + + HttpURLConnection conn = postJson("/mcp/check", json); + int responseCode = conn.getResponseCode(); + + if (responseCode == 404) { + // MCP not enabled, skip this test + return; + } + + String body = getContentText(conn); + assertTrue(body.contains("\"ok\":false")); + assertTrue(body.contains("errors")); + } + + /** + * Test: render endpoint returns Base64 PNG. + */ + @Test + void renderEndpointReturnsPngBase64() throws Exception { + String json = "{ \"source\": \"@startuml\\nAlice -> Bob\\n@enduml\" }"; + + HttpURLConnection conn = postJson("/mcp/render", json); + int responseCode = conn.getResponseCode(); + + if (responseCode == 404) { + // MCP not enabled, skip this test + return; + } + + String body = getContentText(conn); + assertTrue(body.contains("\"format\":\"png\"")); + assertTrue(body.contains("\"dataBase64\"")); + } + + /** + * Test: metadata endpoint returns participants. + */ + @Test + void metadataEndpointReturnsParticipants() throws Exception { + String json = "{ \"source\": \"@startuml\\nAlice -> Bob\\n@enduml\" }"; + + HttpURLConnection conn = postJson("/mcp/metadata", json); + int responseCode = conn.getResponseCode(); + + if (responseCode == 404) { + // MCP not enabled, skip this test + return; + } + + String body = getContentText(conn); + assertTrue(body.contains("Alice")); + assertTrue(body.contains("Bob")); + } + + /** + * Test: workspace lifecycle. + */ + @Test + void workspaceLifecycle() throws Exception { + String sessionId = "test-session-" + System.currentTimeMillis(); + + // 1) create workspace (diagram) + String createJson = "{ \"sessionId\":\"" + sessionId + "\", " + + "\"name\":\"test.puml\", " + + "\"source\":\"@startuml\\nAlice->Bob\\n@enduml\" }"; + HttpURLConnection r1 = postJson("/mcp/workspace/create", createJson); + + int responseCode = r1.getResponseCode(); + if (responseCode == 404) { + // MCP not enabled, skip this test + return; + } + + assertEquals(200, responseCode); + String body1 = getContentText(r1); + + // Extract diagramId from response + JsonObject createResp = GSON.fromJson(body1, JsonObject.class); + String diagramId = createResp.get("diagramId").getAsString(); + assertNotNull(diagramId); + + // 2) put file (update diagram) + String putJson = "{ \"sessionId\":\"" + sessionId + "\", " + + "\"diagramId\":\"" + diagramId + "\", " + + "\"source\":\"@startuml\\nAlice->Charlie\\n@enduml\" }"; + HttpURLConnection r2 = postJson("/mcp/workspace/put", putJson); + assertEquals(200, r2.getResponseCode()); + + // 3) render file + String renderJson = "{ \"sessionId\":\"" + sessionId + "\", " + + "\"diagramId\":\"" + diagramId + "\" }"; + HttpURLConnection r3 = postJson("/mcp/workspace/render", renderJson); + + String body3 = getContentText(r3); + assertTrue(body3.contains("\"dataBase64\"")); + } + + /** + * Test: invalid JSON must return 400. + */ + @Test + void invalidJsonShouldReturn400() throws Exception { + HttpURLConnection conn = postJson("/mcp/check", "{ invalid json }"); + + int responseCode = conn.getResponseCode(); + if (responseCode == 404) { + // MCP not enabled, skip this test + return; + } + + assertEquals(400, responseCode); + } +}