diff --git a/src/workshop_mcp/complexity_analysis/__init__.py b/src/workshop_mcp/complexity_analysis/__init__.py new file mode 100644 index 0000000..1736ace --- /dev/null +++ b/src/workshop_mcp/complexity_analysis/__init__.py @@ -0,0 +1,17 @@ +"""Complexity analysis tools for measuring Python code complexity metrics.""" + +__version__ = "0.1.0" + +from .calculator import CognitiveCalculator, CyclomaticCalculator +from .metrics import ClassMetrics, FileMetrics, FunctionMetrics, analyze_complexity +from .patterns import ComplexityCategory + +__all__ = [ + "CyclomaticCalculator", + "CognitiveCalculator", + "FunctionMetrics", + "ClassMetrics", + "FileMetrics", + "ComplexityCategory", + "analyze_complexity", +] diff --git a/src/workshop_mcp/complexity_analysis/calculator.py b/src/workshop_mcp/complexity_analysis/calculator.py new file mode 100644 index 0000000..83158b2 --- /dev/null +++ b/src/workshop_mcp/complexity_analysis/calculator.py @@ -0,0 +1,120 @@ +"""Cyclomatic and cognitive complexity calculators using Astroid.""" + +import astroid + + +class CyclomaticCalculator: + """Calculates cyclomatic complexity for Python functions. + + Cyclomatic complexity counts the number of linearly independent paths + through a function. Higher values indicate more complex branching logic. + """ + + def calculate(self, node: astroid.FunctionDef | astroid.AsyncFunctionDef) -> int: + """Calculate cyclomatic complexity for a function node. + + Args: + node: An Astroid FunctionDef or AsyncFunctionDef node. + + Returns: + Cyclomatic complexity score (minimum 1). 
+ """ + complexity = 1 # Base complexity + complexity += self._count_branches(node) + return complexity + + def _count_branches(self, node: astroid.NodeNG) -> int: + """Recursively count branching constructs.""" + count = 0 + for _child in node.nodes_of_class( + ( + astroid.If, + astroid.For, + astroid.While, + astroid.ExceptHandler, + astroid.With, + astroid.Assert, + astroid.IfExp, + astroid.Comprehension, + ) + ): + count += 1 + + # Count boolean operators in conditions + for bool_op in node.nodes_of_class(astroid.BoolOp): + # Each 'and'/'or' adds a new path + count += len(bool_op.values) - 1 + + return count + + +class CognitiveCalculator: + """Calculates cognitive complexity (Sonar's metric) for Python functions. + + Cognitive complexity measures how difficult code is to understand, + applying nesting penalties for structures inside other structures. + """ + + def calculate(self, node: astroid.FunctionDef | astroid.AsyncFunctionDef) -> int: + """Calculate cognitive complexity for a function node. + + Args: + node: An Astroid FunctionDef or AsyncFunctionDef node. + + Returns: + Cognitive complexity score (minimum 0). 
+ """ + return self._walk(node, nesting=0, func_name=node.name) + + def _walk(self, node: astroid.NodeNG, nesting: int, func_name: str) -> int: + """Recursively walk the AST accumulating cognitive complexity.""" + total = 0 + + for child in node.get_children(): + if isinstance(child, (astroid.FunctionDef, astroid.AsyncFunctionDef)): + # Nested function definitions increase nesting + total += self._walk(child, nesting + 1, func_name) + continue + + # Increment for breaks in linear flow + nesting penalty + if isinstance(child, astroid.If): + total += 1 + nesting # +1 for if + nesting penalty + total += self._walk(child, nesting + 1, func_name) + continue + elif isinstance(child, (astroid.For, astroid.While)): + total += 1 + nesting + total += self._walk(child, nesting + 1, func_name) + continue + elif isinstance(child, astroid.ExceptHandler): + total += 1 + nesting + total += self._walk(child, nesting + 1, func_name) + continue + elif isinstance(child, astroid.With): + total += 1 + nesting + total += self._walk(child, nesting + 1, func_name) + continue + elif isinstance(child, astroid.IfExp): + total += 1 + nesting + total += self._walk(child, nesting, func_name) + continue + + # Boolean operators: +1 for each sequence + if isinstance(child, astroid.BoolOp): + total += 1 + + # Recursion: +1 when function calls itself + if isinstance(child, astroid.Call): + call_name = self._get_call_name(child) + if call_name == func_name: + total += 1 + + total += self._walk(child, nesting, func_name) + + return total + + @staticmethod + def _get_call_name(node: astroid.Call) -> str | None: + """Get the simple name of a function call.""" + if isinstance(node.func, astroid.Name): + return node.func.name + return None diff --git a/src/workshop_mcp/complexity_analysis/metrics.py b/src/workshop_mcp/complexity_analysis/metrics.py new file mode 100644 index 0000000..603bbc0 --- /dev/null +++ b/src/workshop_mcp/complexity_analysis/metrics.py @@ -0,0 +1,333 @@ +"""Function and class metric 
collectors for complexity analysis.""" + +from dataclasses import dataclass, field + +import astroid + +from ..core.ast_utils import extract_classes, extract_functions, parse_source +from .calculator import CognitiveCalculator, CyclomaticCalculator +from .patterns import ( + DEFAULT_COGNITIVE_THRESHOLD, + DEFAULT_CYCLOMATIC_THRESHOLD, + DEFAULT_MAX_CLASS_METHODS, + DEFAULT_MAX_FUNCTION_LENGTH, + DEFAULT_MAX_INHERITANCE_DEPTH, + DEFAULT_MAX_NESTING_DEPTH, + DEFAULT_MAX_PARAMETERS, + ComplexityCategory, + severity_for_cognitive, + severity_for_cyclomatic, +) + + +@dataclass +class FunctionMetrics: + """Aggregated metrics for a single function.""" + + name: str + line: int + end_line: int + cyclomatic: int + cognitive: int + length: int + params: int + nesting_depth: int + + +@dataclass +class ClassMetrics: + """Aggregated metrics for a single class.""" + + name: str + line: int + method_count: int + inheritance_depth: int + + +@dataclass +class FileMetrics: + """File-level summary statistics.""" + + total_functions: int + average_complexity: float + max_complexity: int + complex_functions: int # functions exceeding cyclomatic threshold + + +@dataclass +class ComplexityIssue: + """A complexity issue reported by the analyzer.""" + + tool: str + category: str + severity: str + message: str + line: int + function: str | None = None + metrics: dict | None = None + suggestion: str | None = None + + +@dataclass +class ComplexityResult: + """Full result from complexity analysis.""" + + issues: list[ComplexityIssue] = field(default_factory=list) + file_metrics: FileMetrics | None = None + function_metrics: list[FunctionMetrics] = field(default_factory=list) + class_metrics: list[ClassMetrics] = field(default_factory=list) + + +def analyze_complexity( + source_code: str, + *, + file_path: str | None = None, + cyclomatic_threshold: int = DEFAULT_CYCLOMATIC_THRESHOLD, + cognitive_threshold: int = DEFAULT_COGNITIVE_THRESHOLD, + max_function_length: int = 
DEFAULT_MAX_FUNCTION_LENGTH, +) -> ComplexityResult: + """Analyze complexity metrics for Python source code. + + Args: + source_code: Python source code to analyze. + file_path: Optional file path for context. + cyclomatic_threshold: Threshold for cyclomatic complexity warnings. + cognitive_threshold: Threshold for cognitive complexity warnings. + max_function_length: Maximum function length before warning. + + Returns: + ComplexityResult with issues, function metrics, and file-level summary. + + Raises: + SyntaxError: If source code cannot be parsed. + """ + tree = parse_source(source_code, file_path) + result = ComplexityResult() + + cyclo_calc = CyclomaticCalculator() + cog_calc = CognitiveCalculator() + + func_infos = extract_functions(tree) + func_nodes = list(tree.nodes_of_class((astroid.FunctionDef, astroid.AsyncFunctionDef))) + + # Build a map from (name, line) to AST node for metric calculation + node_map: dict[tuple[str, int], astroid.FunctionDef | astroid.AsyncFunctionDef] = {} + for node in func_nodes: + node_map[(node.name, node.lineno)] = node + + cyclomatic_scores: list[int] = [] + + for func_info in func_infos: + node = node_map.get((func_info.name, func_info.line_number)) + if node is None: + continue + + cyclomatic = cyclo_calc.calculate(node) + cognitive = cog_calc.calculate(node) + length = func_info.end_line_number - func_info.line_number + 1 + params = len(func_info.parameters) + nesting = _max_nesting_depth(node) + + fm = FunctionMetrics( + name=func_info.name, + line=func_info.line_number, + end_line=func_info.end_line_number, + cyclomatic=cyclomatic, + cognitive=cognitive, + length=length, + params=params, + nesting_depth=nesting, + ) + result.function_metrics.append(fm) + cyclomatic_scores.append(cyclomatic) + + metrics_dict = { + "cyclomatic": cyclomatic, + "cognitive": cognitive, + "lines": length, + "params": params, + "nesting_depth": nesting, + } + + # Check cyclomatic threshold + if cyclomatic > cyclomatic_threshold: + sev = 
severity_for_cyclomatic(cyclomatic, cyclomatic_threshold) + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.HIGH_CYCLOMATIC_COMPLEXITY.value, + severity=sev, + message=( + f"Function '{func_info.name}' has cyclomatic complexity " + f"of {cyclomatic} (threshold: {cyclomatic_threshold})" + ), + line=func_info.line_number, + function=func_info.name, + metrics=metrics_dict, + suggestion="Consider breaking this function into smaller, focused functions", + ) + ) + + # Check cognitive threshold + if cognitive > cognitive_threshold: + sev = severity_for_cognitive(cognitive, cognitive_threshold) + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.HIGH_COGNITIVE_COMPLEXITY.value, + severity=sev, + message=( + f"Function '{func_info.name}' has cognitive complexity " + f"of {cognitive} (threshold: {cognitive_threshold})" + ), + line=func_info.line_number, + function=func_info.name, + metrics=metrics_dict, + suggestion="Reduce nesting and simplify conditional logic", + ) + ) + + # Check function length + if length > max_function_length: + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.LONG_FUNCTION.value, + severity="warning", + message=( + f"Function '{func_info.name}' is {length} lines long " + f"(threshold: {max_function_length})" + ), + line=func_info.line_number, + function=func_info.name, + metrics=metrics_dict, + suggestion="Extract logic into helper functions", + ) + ) + + # Check parameter count + if params > DEFAULT_MAX_PARAMETERS: + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.TOO_MANY_PARAMETERS.value, + severity="warning", + message=( + f"Function '{func_info.name}' has {params} parameters " + f"(threshold: {DEFAULT_MAX_PARAMETERS})" + ), + line=func_info.line_number, + function=func_info.name, + metrics=metrics_dict, + suggestion="Group related parameters into a dataclass or dict", + ) + ) 
+ + # Check nesting depth + if nesting > DEFAULT_MAX_NESTING_DEPTH: + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.DEEP_NESTING.value, + severity="warning", + message=( + f"Function '{func_info.name}' has nesting depth of {nesting} " + f"(threshold: {DEFAULT_MAX_NESTING_DEPTH})" + ), + line=func_info.line_number, + function=func_info.name, + metrics=metrics_dict, + suggestion="Use early returns or extract nested blocks into functions", + ) + ) + + # Class metrics + class_infos = extract_classes(tree) + class_nodes = list(tree.nodes_of_class(astroid.ClassDef)) + class_node_map: dict[tuple[str, int], astroid.ClassDef] = {} + for cn in class_nodes: + class_node_map[(cn.name, cn.lineno)] = cn + + for class_info in class_infos: + cn = class_node_map.get((class_info.name, class_info.line_number)) + inheritance_depth = _inheritance_depth(cn) if cn else 0 + method_count = len(class_info.methods) + + cm = ClassMetrics( + name=class_info.name, + line=class_info.line_number, + method_count=method_count, + inheritance_depth=inheritance_depth, + ) + result.class_metrics.append(cm) + + if method_count > DEFAULT_MAX_CLASS_METHODS: + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.LARGE_CLASS.value, + severity="warning", + message=( + f"Class '{class_info.name}' has {method_count} methods " + f"(threshold: {DEFAULT_MAX_CLASS_METHODS})" + ), + line=class_info.line_number, + function=None, + suggestion="Consider splitting into smaller, focused classes", + ) + ) + + if inheritance_depth > DEFAULT_MAX_INHERITANCE_DEPTH: + result.issues.append( + ComplexityIssue( + tool="complexity", + category=ComplexityCategory.DEEP_INHERITANCE.value, + severity="warning", + message=( + f"Class '{class_info.name}' has inheritance depth of " + f"{inheritance_depth} (threshold: {DEFAULT_MAX_INHERITANCE_DEPTH})" + ), + line=class_info.line_number, + function=None, + suggestion="Prefer composition over deep 
inheritance hierarchies", + ) + ) + + # File-level summary + total = len(cyclomatic_scores) + avg = sum(cyclomatic_scores) / total if total else 0.0 + max_c = max(cyclomatic_scores, default=0) + complex_count = sum(1 for s in cyclomatic_scores if s > cyclomatic_threshold) + + result.file_metrics = FileMetrics( + total_functions=total, + average_complexity=round(avg, 2), + max_complexity=max_c, + complex_functions=complex_count, + ) + + return result + + +def _max_nesting_depth(node: astroid.NodeNG, current: int = 0) -> int: + """Calculate the maximum nesting depth inside a function.""" + max_depth = current + + for child in node.get_children(): + if isinstance(child, (astroid.If, astroid.For, astroid.While, astroid.With, astroid.Try)): + child_depth = _max_nesting_depth(child, current + 1) + max_depth = max(max_depth, child_depth) + else: + child_depth = _max_nesting_depth(child, current) + max_depth = max(max_depth, child_depth) + + return max_depth + + +def _inheritance_depth(node: astroid.ClassDef) -> int: + """Calculate the inheritance depth of a class.""" + try: + ancestors = list(node.ancestors()) + return len(ancestors) if ancestors else 0 + except (astroid.InferenceError, StopIteration, RecursionError): + return len(node.bases) diff --git a/src/workshop_mcp/complexity_analysis/patterns.py b/src/workshop_mcp/complexity_analysis/patterns.py new file mode 100644 index 0000000..8aec472 --- /dev/null +++ b/src/workshop_mcp/complexity_analysis/patterns.py @@ -0,0 +1,87 @@ +"""Complexity categories, thresholds, and severity mapping.""" + +from enum import Enum + + +class ComplexityCategory(Enum): + """Categories of complexity issues.""" + + HIGH_CYCLOMATIC_COMPLEXITY = "high_cyclomatic_complexity" + HIGH_COGNITIVE_COMPLEXITY = "high_cognitive_complexity" + LONG_FUNCTION = "long_function" + TOO_MANY_PARAMETERS = "too_many_parameters" + DEEP_NESTING = "deep_nesting" + LARGE_CLASS = "large_class" + DEEP_INHERITANCE = "deep_inheritance" + + +# Cyclomatic complexity 
thresholds +CYCLOMATIC_SIMPLE = 10 +CYCLOMATIC_MODERATE = 20 +CYCLOMATIC_HIGH = 50 + +# Default thresholds +DEFAULT_CYCLOMATIC_THRESHOLD = 10 +DEFAULT_COGNITIVE_THRESHOLD = 15 +DEFAULT_MAX_FUNCTION_LENGTH = 50 +DEFAULT_MAX_PARAMETERS = 5 +DEFAULT_MAX_NESTING_DEPTH = 4 +DEFAULT_MAX_CLASS_METHODS = 20 +DEFAULT_MAX_INHERITANCE_DEPTH = 3 + + +def cyclomatic_label(score: int) -> str: + """Return a human-readable label for a cyclomatic complexity score. + + Args: + score: Cyclomatic complexity value. + + Returns: + One of 'simple', 'moderate', 'high', or 'very high'. + """ + if score <= CYCLOMATIC_SIMPLE: + return "simple" + if score <= CYCLOMATIC_MODERATE: + return "moderate" + if score <= CYCLOMATIC_HIGH: + return "high" + return "very high" + + +def severity_for_cyclomatic(score: int, threshold: int) -> str: + """Return severity string based on how far above threshold. + + Args: + score: Cyclomatic complexity value. + threshold: Configured threshold. + + Returns: + Severity string: 'info', 'warning', 'error', or 'critical'. + """ + if score <= threshold: + return "info" + if score <= CYCLOMATIC_MODERATE: + return "warning" + if score <= CYCLOMATIC_HIGH: + return "error" + return "critical" + + +def severity_for_cognitive(score: int, threshold: int) -> str: + """Return severity string for cognitive complexity. + + Args: + score: Cognitive complexity value. + threshold: Configured threshold. + + Returns: + Severity string. 
+ """ + if score <= threshold: + return "info" + ratio = score / max(threshold, 1) + if ratio <= 2.0: + return "warning" + if ratio <= 3.0: + return "error" + return "critical" diff --git a/src/workshop_mcp/server.py b/src/workshop_mcp/server.py index aa6a760..6c9d887 100644 --- a/src/workshop_mcp/server.py +++ b/src/workshop_mcp/server.py @@ -10,9 +10,11 @@ import json import logging import sys -from dataclasses import dataclass +from dataclasses import asdict, dataclass +from pathlib import Path from typing import Any +from .complexity_analysis import analyze_complexity from .keyword_search import KeywordSearchTool from .logging_context import CorrelationIdFilter, correlation_id_var, request_context from .performance_profiler import PerformanceChecker @@ -290,6 +292,50 @@ def _handle_list_tools(self, request_id: Any) -> dict[str, Any]: ], }, }, + { + "name": "complexity_analysis", + "description": ( + "Analyze Python code for complexity metrics including cyclomatic " + "complexity, cognitive complexity, function length, parameter count, " + "nesting depth, and class metrics. Provides actionable suggestions " + "for reducing complexity." 
+ ), + "inputSchema": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Path to the Python file to analyze", + "minLength": 1, + }, + "source_code": { + "type": "string", + "description": ( + "Python source code string to analyze instead of file" + ), + }, + "cyclomatic_threshold": { + "type": "integer", + "description": ("Cyclomatic complexity threshold (default 10)"), + "default": 10, + }, + "cognitive_threshold": { + "type": "integer", + "description": ("Cognitive complexity threshold (default 15)"), + "default": 15, + }, + "max_function_length": { + "type": "integer", + "description": ("Maximum function length in lines (default 50)"), + "default": 50, + }, + }, + "oneOf": [ + {"required": ["file_path"]}, + {"required": ["source_code"]}, + ], + }, + }, ] } return self._success_response(request_id, result) @@ -308,6 +354,8 @@ def _handle_call_tool(self, request_id: Any, params: dict[str, Any] | None) -> d return self._execute_keyword_search(request_id, arguments) elif name == "performance_check": return self._execute_performance_check(request_id, arguments) + elif name == "complexity_analysis": + return self._execute_complexity_analysis(request_id, arguments) else: return self._error_response( request_id, @@ -571,6 +619,121 @@ def _execute_performance_check( ), ) + def _execute_complexity_analysis( + self, request_id: Any, arguments: dict[str, Any] + ) -> dict[str, Any]: + if not isinstance(arguments, dict): + return self._error_response( + request_id, + JsonRpcError(-32602, "Invalid params", {"expected": "object"}), + ) + + file_path = arguments.get("file_path") + source_code = arguments.get("source_code") + + if not file_path and not source_code: + return self._error_response( + request_id, + JsonRpcError(-32602, "Either file_path or source_code must be provided"), + ) + if file_path and source_code: + return self._error_response( + request_id, + JsonRpcError(-32602, "Provide only one of file_path or source_code"), + 
) + + if file_path is not None and not isinstance(file_path, str): + return self._error_response( + request_id, + JsonRpcError(-32602, "file_path must be a string"), + ) + + if file_path: + try: + self.path_validator.validate_exists(file_path, must_be_file=True) + except PathValidationError as e: + return self._error_response( + request_id, + JsonRpcError(-32602, str(e)), + ) + + cyclomatic_threshold = arguments.get("cyclomatic_threshold", 10) + cognitive_threshold = arguments.get("cognitive_threshold", 15) + max_function_length = arguments.get("max_function_length", 50) + + try: + logger.info( + "Executing complexity analysis on %s", + file_path or "source code", + ) + + # Read source from file if needed + if file_path and not source_code: + source_code = Path(file_path).read_text(encoding="utf-8") + + complexity_result = analyze_complexity( + source_code, + file_path=file_path, + cyclomatic_threshold=cyclomatic_threshold, + cognitive_threshold=cognitive_threshold, + max_function_length=max_function_length, + ) + + issues_data = [asdict(issue) for issue in complexity_result.issues] + file_metrics_data = ( + asdict(complexity_result.file_metrics) if complexity_result.file_metrics else {} + ) + + result = { + "content": [ + { + "type": "json", + "json": { + "success": True, + "file_analyzed": file_path or "source_code", + "issues": issues_data, + "file_metrics": file_metrics_data, + }, + } + ], + } + return self._success_response(request_id, result) + + except ValueError as exc: + logger.warning("ValueError in complexity_analysis: %s", exc) + return self._error_response( + request_id, + JsonRpcError(-32602, "Invalid parameters"), + ) + except FileNotFoundError as exc: + logger.warning("FileNotFoundError in complexity_analysis: %s", exc) + return self._error_response( + request_id, + JsonRpcError(-32602, "Resource not found"), + ) + except SyntaxError as exc: + logger.warning("SyntaxError in complexity_analysis: %s", exc) + return self._error_response( + request_id, 
+ JsonRpcError(-32602, "Invalid source code syntax"), + ) + except SecurityValidationError as exc: + logger.warning("Security validation error: %s", exc) + return self._error_response( + request_id, + JsonRpcError(-32602, str(exc)), + ) + except Exception: + logger.exception("Error executing complexity_analysis") + return self._error_response( + request_id, + JsonRpcError( + -32603, + "Internal error", + {"correlation_id": correlation_id_var.get()}, + ), + ) + def _success_response(self, request_id: Any, result: dict[str, Any]) -> dict[str, Any]: return {"jsonrpc": JSONRPC_VERSION, "id": request_id, "result": result} diff --git a/tests/test_complexity_analysis.py b/tests/test_complexity_analysis.py new file mode 100644 index 0000000..65bc28e --- /dev/null +++ b/tests/test_complexity_analysis.py @@ -0,0 +1,703 @@ +"""Tests for the complexity analysis module.""" + +import pytest + +from workshop_mcp.complexity_analysis.calculator import ( + CognitiveCalculator, + CyclomaticCalculator, +) +from workshop_mcp.complexity_analysis.metrics import analyze_complexity +from workshop_mcp.complexity_analysis.patterns import ( + ComplexityCategory, + cyclomatic_label, + severity_for_cognitive, + severity_for_cyclomatic, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _parse_function(source: str): + """Parse source and return the first FunctionDef node.""" + import astroid + + tree = astroid.parse(source) + for node in tree.nodes_of_class((astroid.FunctionDef, astroid.AsyncFunctionDef)): + return node + raise ValueError("No function found in source") + + +# =========================================================================== +# CyclomaticCalculator tests +# =========================================================================== + + +class TestCyclomaticCalculator: + """Test cyclomatic complexity calculation.""" + + def 
test_simple_function(self): + """A function with no branching has complexity 1.""" + source = """ +def simple(): + x = 1 + return x +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 1 + + def test_single_if(self): + """A single if statement adds 1.""" + source = """ +def fn(x): + if x > 0: + return x + return -x +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 2 + + def test_if_elif_else(self): + """if + elif is 2 extra branches.""" + source = """ +def fn(x): + if x > 0: + return 1 + elif x == 0: + return 0 + else: + return -1 +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 3 # 1 base + if + elif + + def test_for_loop(self): + """A for loop adds 1.""" + source = """ +def fn(items): + total = 0 + for item in items: + total += item + return total +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 2 + + def test_while_loop(self): + """A while loop adds 1.""" + source = """ +def fn(): + i = 0 + while i < 10: + i += 1 + return i +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 2 + + def test_except_handler(self): + """Each except handler adds 1.""" + source = """ +def fn(): + try: + return 1 / 0 + except ZeroDivisionError: + return 0 + except ValueError: + return -1 +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 3 # 1 base + 2 except + + def test_with_statement(self): + """A with statement adds 1.""" + source = """ +def fn(): + with open('f') as fh: + return fh.read() +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 2 + + def test_assert_statement(self): + """An assert adds 1.""" + source = """ +def fn(x): + assert x > 0 + return x +""" + node = _parse_function(source) + calc = 
CyclomaticCalculator() + assert calc.calculate(node) == 2 + + def test_boolean_and(self): + """Boolean 'and' adds 1 per operator.""" + source = """ +def fn(a, b): + if a and b: + return True +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + # 1 base + 1 if + 1 and + assert calc.calculate(node) == 3 + + def test_boolean_or(self): + """Boolean 'or' adds 1 per operator.""" + source = """ +def fn(a, b, c): + if a or b or c: + return True +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + # 1 base + 1 if + 2 or (three operands = 2 operators) + assert calc.calculate(node) == 4 + + def test_comprehension(self): + """List comprehension adds 1.""" + source = """ +def fn(items): + return [x for x in items] +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + # 1 base + 1 comprehension + assert calc.calculate(node) >= 2 + + def test_ternary_expression(self): + """Ternary (IfExp) adds 1.""" + source = """ +def fn(x): + return x if x > 0 else -x +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + assert calc.calculate(node) == 2 # 1 base + 1 IfExp + + def test_complex_function(self): + """A function with many branches has high complexity.""" + source = """ +def process(data, flag, mode): + if not data: + return None + result = [] + for item in data: + if flag and item > 0: + if mode == 'add': + result.append(item) + elif mode == 'double': + result.append(item * 2) + else: + result.append(0) + elif item == 0: + continue + else: + try: + result.append(-item) + except TypeError: + pass + while len(result) > 100: + result.pop() + return result +""" + node = _parse_function(source) + calc = CyclomaticCalculator() + score = calc.calculate(node) + assert score >= 10 + + +# =========================================================================== +# CognitiveCalculator tests +# =========================================================================== + + +class TestCognitiveCalculator: + """Test 
cognitive complexity calculation.""" + + def test_simple_function(self): + """No flow breaks = 0 cognitive complexity.""" + source = """ +def simple(): + return 42 +""" + node = _parse_function(source) + calc = CognitiveCalculator() + assert calc.calculate(node) == 0 + + def test_single_if(self): + """Single if at top level adds 1.""" + source = """ +def fn(x): + if x: + return x +""" + node = _parse_function(source) + calc = CognitiveCalculator() + assert calc.calculate(node) == 1 + + def test_nesting_penalty(self): + """Nested if adds 1 + nesting level.""" + source = """ +def fn(x, y): + if x: + if y: + return True +""" + node = _parse_function(source) + calc = CognitiveCalculator() + # outer if: +1 (nesting=0) + # inner if: +1 + 1 (nesting=1) + assert calc.calculate(node) == 3 + + def test_loop_with_nested_if(self): + """for loop + nested if both get nesting penalties.""" + source = """ +def fn(items): + for item in items: + if item > 0: + pass +""" + node = _parse_function(source) + calc = CognitiveCalculator() + # for: +1 (nesting=0) + # if inside for: +1 + 1 (nesting=1) + assert calc.calculate(node) == 3 + + def test_boolean_operator(self): + """Boolean operators add 1 per operator group.""" + source = """ +def fn(a, b): + if a and b: + return True +""" + node = _parse_function(source) + calc = CognitiveCalculator() + # if: +1, and: +1 + assert calc.calculate(node) == 2 + + def test_recursion_detection(self): + """Recursive call adds 1.""" + source = """ +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n - 1) +""" + node = _parse_function(source) + calc = CognitiveCalculator() + # if: +1, recursion: +1 + assert calc.calculate(node) == 2 + + def test_deeply_nested(self): + """Deeply nested code has high cognitive complexity.""" + source = """ +def fn(a, b, c, d): + if a: + for x in b: + if c: + while d: + pass +""" + node = _parse_function(source) + calc = CognitiveCalculator() + # if(0): +1, for(1): +1+1=2, if(2): +1+2=3, while(3): +1+3=4 + 
assert calc.calculate(node) == 10 + + +# =========================================================================== +# analyze_complexity integration tests +# =========================================================================== + + +class TestAnalyzeComplexity: + """Test the main analyze_complexity function.""" + + def test_simple_code_no_issues(self): + """Simple code should produce no issues.""" + source = """ +def add(a, b): + return a + b +""" + result = analyze_complexity(source) + assert len(result.issues) == 0 + assert result.file_metrics is not None + assert result.file_metrics.total_functions == 1 + assert result.file_metrics.max_complexity == 1 + + def test_invalid_syntax(self): + """Invalid Python raises SyntaxError.""" + with pytest.raises(SyntaxError): + analyze_complexity("def broken(") + + def test_high_cyclomatic_detected(self): + """Functions above the cyclomatic threshold produce an issue.""" + source = """ +def complex_fn(a, b, c, d, e): + if a: + pass + if b: + pass + if c: + pass + if d: + pass + if e: + pass + for x in range(10): + if x > 5: + pass + while a: + break + try: + pass + except ValueError: + pass + except TypeError: + pass + assert a +""" + result = analyze_complexity(source, cyclomatic_threshold=5) + cats = [i.category for i in result.issues] + assert ComplexityCategory.HIGH_CYCLOMATIC_COMPLEXITY.value in cats + + def test_high_cognitive_detected(self): + """Functions above the cognitive threshold produce an issue.""" + source = """ +def deeply_nested(a, b, c, d): + if a: + for x in b: + if c: + while d: + if a and b: + pass +""" + result = analyze_complexity(source, cognitive_threshold=5) + cats = [i.category for i in result.issues] + assert ComplexityCategory.HIGH_COGNITIVE_COMPLEXITY.value in cats + + def test_long_function_detected(self): + """Functions exceeding max_function_length produce an issue.""" + lines = [" x = 1"] * 60 + source = "def long_fn():\n" + "\n".join(lines) + "\n return x\n" + result = 
analyze_complexity(source, max_function_length=50)
+        cats = [i.category for i in result.issues]
+        assert ComplexityCategory.LONG_FUNCTION.value in cats
+
+    def test_too_many_parameters_detected(self):
+        """Functions with too many parameters produce an issue."""
+        source = """
+def many_params(a, b, c, d, e, f, g):
+    return a + b + c + d + e + f + g
+"""
+        result = analyze_complexity(source)
+        cats = [i.category for i in result.issues]
+        assert ComplexityCategory.TOO_MANY_PARAMETERS.value in cats
+
+    def test_self_cls_counted_in_params(self):
+        """self and cls are included in the raw parameter count.
+
+        Note: The core ast_utils extract_functions includes self/cls in
+        parameters, so the param check in metrics.py counts all params.
+        A method with (self, a, b, c, d, e) has 6 params total which
+        exceeds the default threshold of 5.
+        """
+        source = """
+class Foo:
+    def method(self, a, b, c, d):
+        pass
+"""
+        result = analyze_complexity(source)
+        # self + 4 params = 5 total, at threshold, should not trigger
+        param_issues = [
+            i for i in result.issues if i.category == ComplexityCategory.TOO_MANY_PARAMETERS.value
+        ]
+        assert len(param_issues) == 0
+
+    def test_deep_nesting_detected(self):
+        """Deeply nested code produces an issue."""
+        source = """
+def deep():
+    if True:
+        for x in []:
+            if True:
+                while True:
+                    if True:
+                        pass
+"""
+        result = analyze_complexity(source, cyclomatic_threshold=100)
+        cats = [i.category for i in result.issues]
+        assert ComplexityCategory.DEEP_NESTING.value in cats
+
+    def test_class_metrics(self):
+        """Classes with many methods are detected."""
+        methods = "\n".join(f"    def method_{i}(self):\n        pass\n" for i in range(25))
+        source = f"class BigClass:\n{methods}"
+        result = analyze_complexity(source, max_function_length=100)
+        cats = [i.category for i in result.issues]
+        assert ComplexityCategory.LARGE_CLASS.value in cats
+
+    def test_threshold_configurability(self):
+        """Lowering thresholds catches more issues."""
+        source = """
+def fn(x):
+    if x:
+        return 1
+    return 0
+"""
+        # Default threshold (10) should not trigger
+        result_default = analyze_complexity(source)
+        assert len(result_default.issues) == 0
+
+        # Threshold of 1 should trigger
+        result_low = analyze_complexity(source, cyclomatic_threshold=1)
+        assert len(result_low.issues) > 0
+
+    def test_file_metrics_aggregation(self):
+        """File metrics correctly aggregate function-level data."""
+        source = """
+def fn1():
+    pass
+
+def fn2(x):
+    if x:
+        return 1
+    return 0
+
+def fn3(a, b, c):
+    if a:
+        if b:
+            if c:
+                return True
+    return False
+"""
+        result = analyze_complexity(source)
+        fm = result.file_metrics
+        assert fm is not None
+        assert fm.total_functions == 3
+        assert fm.average_complexity > 0
+        assert fm.max_complexity >= 1
+
+    def test_empty_function(self):
+        """Empty/pass function has complexity 1."""
+        source = """
+def noop():
+    pass
+"""
+        result = analyze_complexity(source)
+        assert result.file_metrics is not None
+        assert result.file_metrics.max_complexity == 1
+        assert len(result.issues) == 0
+
+    def test_single_line_function(self):
+        """Single-expression function has complexity 1."""
+        source = """
+def identity(x): return x
+"""
+        result = analyze_complexity(source)
+        assert result.file_metrics is not None
+        assert result.file_metrics.max_complexity == 1
+
+    def test_file_path_analysis(self, tmp_path):
+        """Analyze source while also supplying the matching on-disk file path."""
+        test_file = tmp_path / "test.py"
+        source = "def fn(x):\n    if x:\n        return x\n    return -x\n"
+        test_file.write_text(source)
+        result = analyze_complexity(source, file_path=str(test_file))
+        assert result.file_metrics is not None
+        assert result.file_metrics.total_functions == 1
+
+    def test_issue_output_format(self):
+        """Issues have the expected fields."""
+        source = """
+def complex_fn(a, b, c, d, e, f, g):
+    if a and b:
+        for x in c:
+            if d:
+                while e:
+                    try:
+                        pass
+                    except Exception:
+                        pass
+    if f:
+        pass
+    if g:
+        pass
+"""
+        result = analyze_complexity(source, cyclomatic_threshold=5)
+        assert len(result.issues) > 0
+        issue = result.issues[0]
+        assert issue.tool == "complexity"
+        assert issue.category is not None
+        assert issue.severity is not None
+        assert issue.message is not None
+        assert issue.line is not None
+        assert issue.suggestion is not None
+
+
+# ===========================================================================
+# Patterns / thresholds tests
+# ===========================================================================
+
+
+class TestPatterns:
+    """Test threshold helpers."""
+
+    def test_cyclomatic_labels(self):
+        assert cyclomatic_label(5) == "simple"
+        assert cyclomatic_label(10) == "simple"
+        assert cyclomatic_label(15) == "moderate"
+        assert cyclomatic_label(20) == "moderate"
+        assert cyclomatic_label(30) == "high"
+        assert cyclomatic_label(50) == "high"
+        assert cyclomatic_label(51) == "very high"
+
+    def test_severity_for_cyclomatic(self):
+        assert severity_for_cyclomatic(5, 10) == "info"
+        assert severity_for_cyclomatic(15, 10) == "warning"
+        assert severity_for_cyclomatic(25, 10) == "error"
+        assert severity_for_cyclomatic(55, 10) == "critical"
+
+    def test_severity_for_cognitive(self):
+        assert severity_for_cognitive(10, 15) == "info"
+        assert severity_for_cognitive(20, 15) == "warning"
+        assert severity_for_cognitive(40, 15) == "error"
+        assert severity_for_cognitive(60, 15) == "critical"
+
+
+# ===========================================================================
+# MCP server integration tests
+# ===========================================================================
+
+
+class TestMCPIntegration:
+    """Test complexity_analysis tool via MCP server."""
+
+    def test_tool_listed(self):
+        """complexity_analysis appears in list_tools."""
+        from workshop_mcp.server import WorkshopMCPServer
+
+        server = WorkshopMCPServer()
+        response = server._handle_request({"jsonrpc": "2.0", "id": 1, "method": "list_tools"})
+        tools = response["result"]["tools"]
+        tool_names = [t["name"] for t in tools]
+        assert "complexity_analysis" in tool_names
+
+    def test_tool_schema(self):
+        """complexity_analysis has the expected input schema."""
+        from workshop_mcp.server import WorkshopMCPServer
+
+        server = WorkshopMCPServer()
+        response = server._handle_request({"jsonrpc": "2.0", "id": 1, "method": "list_tools"})
+        tools = response["result"]["tools"]
+        tool = next(t for t in tools if t["name"] == "complexity_analysis")
+        props = tool["inputSchema"]["properties"]
+        assert "file_path" in props
+        assert "source_code" in props
+        assert "cyclomatic_threshold" in props
+        assert "cognitive_threshold" in props
+        assert "max_function_length" in props
+
+    def test_tool_callable_with_source_code(self):
+        """complexity_analysis returns results when called with source_code."""
+        from workshop_mcp.server import WorkshopMCPServer
+
+        server = WorkshopMCPServer()
+        source = """
+def fn(x):
+    if x > 0:
+        return x
+    return -x
+"""
+        request = {
+            "jsonrpc": "2.0",
+            "id": 1,
+            "method": "call_tool",
+            "params": {
+                "name": "complexity_analysis",
+                "arguments": {"source_code": source},
+            },
+        }
+        response = server._handle_request(request)
+        assert "result" in response
+        result_json = response["result"]["content"][0]["json"]
+        assert result_json["success"] is True
+        assert "issues" in result_json
+        assert "file_metrics" in result_json
+        assert result_json["file_metrics"]["total_functions"] == 1
+
+    def test_tool_callable_with_file_path(self, tmp_path, monkeypatch):
+        """complexity_analysis works with a file path."""
+        monkeypatch.setenv("MCP_ALLOWED_ROOTS", str(tmp_path))
+
+        from workshop_mcp.server import WorkshopMCPServer
+
+        server = WorkshopMCPServer()
+        test_file = tmp_path / "test.py"
+        test_file.write_text("def fn():\n    pass\n")
+
+        request = {
+            "jsonrpc": "2.0",
+            "id": 1,
+            "method": "call_tool",
+            "params": {
+                "name": "complexity_analysis",
+                "arguments": {"file_path": str(test_file)},
+            },
+        }
+        response = server._handle_request(request)
+        assert "result" in response
+        result_json = response["result"]["content"][0]["json"]
+        assert result_json["success"] is True
+
+    def test_tool_error_no_input(self):
+        """complexity_analysis returns error when no input is given."""
+        from workshop_mcp.server import WorkshopMCPServer
+
+        server = WorkshopMCPServer()
+        request = {
+            "jsonrpc": "2.0",
+            "id": 1,
+            "method": "call_tool",
+            "params": {
+                "name": "complexity_analysis",
+                "arguments": {},
+            },
+        }
+        response = server._handle_request(request)
+        assert "error" in response
+
+    def test_tool_custom_thresholds(self):
+        """complexity_analysis respects custom thresholds."""
+        from workshop_mcp.server import WorkshopMCPServer
+
+        server = WorkshopMCPServer()
+        source = """
+def fn(x):
+    if x:
+        return 1
+    return 0
+"""
+        request = {
+            "jsonrpc": "2.0",
+            "id": 1,
+            "method": "call_tool",
+            "params": {
+                "name": "complexity_analysis",
+                "arguments": {
+                    "source_code": source,
+                    "cyclomatic_threshold": 1,
+                },
+            },
+        }
+        response = server._handle_request(request)
+        result_json = response["result"]["content"][0]["json"]
+        assert len(result_json["issues"]) > 0