diff --git a/.github/workflows/test-toolchain-compat.yml b/.github/workflows/test-toolchain-compat.yml new file mode 100644 index 0000000000..1c00ea2845 --- /dev/null +++ b/.github/workflows/test-toolchain-compat.yml @@ -0,0 +1,25 @@ +name: Toolchain Compatibility + +on: + push: + branches: [master] + pull_request: + workflow_dispatch: + +jobs: + test-toolchain: + name: "Python ${{ matrix.python-version }}" + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Initialize MFC + run: ./mfc.sh init + - name: Lint and test toolchain + run: ./mfc.sh lint diff --git a/CMakeLists.txt b/CMakeLists.txt index 33a196c215..9e3fef30d1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -665,6 +665,7 @@ if (MFC_DOCUMENTATION) OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/docs/documentation/case_constraints.md" DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/gen_case_constraints_docs.py" "${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/case_validator.py" + "${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/params/definitions.py" "${examples_DOCs}" COMMAND "bash" "${CMAKE_CURRENT_SOURCE_DIR}/docs/gen_constraints.sh" "${CMAKE_CURRENT_SOURCE_DIR}" @@ -684,10 +685,12 @@ if (MFC_DOCUMENTATION) ) # Generate parameters.md from parameter registry + # docs_gen.py now AST-parses case_validator.py, so it must be a dependency file(GLOB_RECURSE params_SRCs CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/params/*.py") add_custom_command( OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/docs/documentation/parameters.md" DEPENDS "${params_SRCs}" + "${CMAKE_CURRENT_SOURCE_DIR}/toolchain/mfc/case_validator.py" COMMAND "bash" "${CMAKE_CURRENT_SOURCE_DIR}/docs/gen_parameters.sh" "${CMAKE_CURRENT_SOURCE_DIR}" COMMENT "Generating parameters.md" diff --git a/toolchain/mfc/case_validator.py 
b/toolchain/mfc/case_validator.py index fad1243111..0af9a02867 100644 --- a/toolchain/mfc/case_validator.py +++ b/toolchain/mfc/case_validator.py @@ -14,9 +14,11 @@ # pylint: disable=too-many-lines # Justification: Comprehensive validator covering all MFC parameter constraints +import re from typing import Dict, Any, List, Set from functools import lru_cache from .common import MFCException +from .params.definitions import CONSTRAINTS @lru_cache(maxsize=1) @@ -144,6 +146,10 @@ def check_model_eqns_and_num_fluids(self): def check_igr(self): """Checks constraints regarding IGR order""" igr = self.get('igr', 'F') == 'T' + igr_pres_lim = self.get('igr_pres_lim', 'F') == 'T' + + self.prohibit(igr_pres_lim and not igr, + "igr_pres_lim requires igr to be enabled") if not igr: return @@ -152,7 +158,6 @@ def check_igr(self): m = self.get('m', 0) n = self.get('n', 0) p = self.get('p', 0) - self.prohibit(igr_order not in [None, 3, 5], "igr_order must be 3 or 5") if igr_order: @@ -191,6 +196,10 @@ def check_weno(self): def check_muscl(self): """Check constraints regarding MUSCL order""" recon_type = self.get('recon_type', 1) + int_comp = self.get('int_comp', 'F') == 'T' + + self.prohibit(int_comp and recon_type != 2, + "int_comp (THINC interface compression) requires recon_type = 2 (MUSCL)") # MUSCL_TYPE = 2 if recon_type != 2: @@ -1201,6 +1210,10 @@ def check_probe_integral_output(self): def check_hyperelasticity(self): """Checks hyperelasticity constraints""" hyperelasticity = self.get('hyperelasticity', 'F') == 'T' + pre_stress = self.get('pre_stress', 'F') == 'T' + + self.prohibit(pre_stress and not hyperelasticity, + "pre_stress requires hyperelasticity to be enabled") if not hyperelasticity: return @@ -1911,14 +1924,27 @@ def _format_errors(self) -> str: err_lower = err.lower() if "must be positive" in err_lower or "must be set" in err_lower: lines.append(" [dim]Check that this required parameter is defined in your case file[/dim]") - elif "weno_order" in err_lower: - 
lines.append(" [dim]Valid values: 1, 3, 5, or 7[/dim]") - elif "riemann_solver" in err_lower: - lines.append(" [dim]Valid values: 1 (HLL), 2 (HLLC), 3 (Exact), etc.[/dim]") - elif "model_eqns" in err_lower: - lines.append(" [dim]Valid values: 1, 2 (5-eq), 3 (6-eq), or 4[/dim]") - elif "boundary" in err_lower or "bc_" in err_lower: + continue + if "boundary" in err_lower or "bc_" in err_lower: lines.append(" [dim]Common BC values: -1 (periodic), -2 (reflective), -3 (extrapolation)[/dim]") + continue + + # Auto-generate hints from CONSTRAINTS with value_labels + for param_name, constraint in CONSTRAINTS.items(): + if not re.search(r'\b' + re.escape(param_name.lower()) + r'\b', err_lower): + continue + choices = constraint.get("choices") + if not choices: + continue + labels = constraint.get("value_labels", {}) + if labels: + items = [f"{v} ({labels[v]})" if v in labels else str(v) + for v in choices] + hint = f"Valid values: {', '.join(items)}" + else: + hint = f"Valid values: {choices}" + lines.append(f" [dim]{hint}[/dim]") + break lines.append("") lines.append("[dim]Tip: Run './mfc.sh validate case.py' for detailed validation[/dim]") diff --git a/toolchain/mfc/gen_case_constraints_docs.py b/toolchain/mfc/gen_case_constraints_docs.py index b485f344c1..982c50da01 100644 --- a/toolchain/mfc/gen_case_constraints_docs.py +++ b/toolchain/mfc/gen_case_constraints_docs.py @@ -6,17 +6,16 @@ maps them to parameters and stages, and emits Markdown to stdout. Also generates case design playbook from curated working examples. 
-""" +""" # pylint: disable=too-many-lines from __future__ import annotations -import ast import json import sys import subprocess -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Set, Iterable, Any +from typing import Dict, List, Iterable, Any from collections import defaultdict HERE = Path(__file__).resolve().parent @@ -24,290 +23,16 @@ REPO_ROOT = HERE.parent.parent EXAMPLES_DIR = REPO_ROOT / "examples" +# Make the params package importable +_toolchain_dir = str(HERE.parent) +if _toolchain_dir not in sys.path: + sys.path.insert(0, _toolchain_dir) -# --------------------------------------------------------------------------- -# Data structures -# --------------------------------------------------------------------------- - -@dataclass -class Rule: - method: str # e.g. "check_igr_simulation" - lineno: int # line number of the prohibit call - params: List[str] # case parameter names used in condition - message: str # user-friendly error message - stages: Set[str] = field(default_factory=set) # e.g. {"simulation", "pre_process"} - - -# --------------------------------------------------------------------------- -# AST analysis: methods, call graph, rules -# --------------------------------------------------------------------------- - -class CaseValidatorAnalyzer(ast.NodeVisitor): - """ - Analyzes the CaseValidator class: - - - collects all methods - - builds a call graph between methods - - extracts all self.prohibit(...) 
rules - """ - - def __init__(self): - super().__init__() - self.in_case_validator = False - self.current_method: str | None = None - - self.methods: Dict[str, ast.FunctionDef] = {} - self.call_graph: Dict[str, Set[str]] = defaultdict(set) - self.rules: List[Rule] = [] - - # Stack of {local_name -> param_name} maps, one per method - self.local_param_stack: List[Dict[str, str]] = [] - - # --- top-level entrypoint --- - - def visit_ClassDef(self, node: ast.ClassDef): - if node.name == "CaseValidator": - self.in_case_validator = True - # collect methods - for item in node.body: - if isinstance(item, ast.FunctionDef): - self.methods[item.name] = item - # now analyze all methods - for method in self.methods.values(): - self._analyze_method(method) - self.in_case_validator = False - else: - self.generic_visit(node) - - # --- per-method analysis --- - - def _analyze_method(self, func: ast.FunctionDef): - """Analyze a single method: local param mapping, call graph, rules.""" - self.current_method = func.name - local_param_map = self._build_local_param_map(func) - self.local_param_stack.append(local_param_map) - self.generic_visit(func) - self.local_param_stack.pop() - self.current_method = None - - def _build_local_param_map(self, func: ast.FunctionDef) -> Dict[str, str]: # pylint: disable=too-many-nested-blocks - """ - Look for assignments like: - igr = self.get('igr', 'F') == 'T' - model_eqns = self.get('model_eqns') - and record local_name -> 'param_name'. 
- """ - m: Dict[str, str] = {} - for stmt in func.body: # pylint: disable=too-many-nested-blocks - if isinstance(stmt, ast.Assign): - # Handle both direct calls and comparisons - value = stmt.value - # Unwrap comparisons like "self.get('igr', 'F') == 'T'" - if isinstance(value, ast.Compare): - value = value.left - - if isinstance(value, ast.Call): - call = value - if ( # pylint: disable=too-many-boolean-expressions - isinstance(call.func, ast.Attribute) - and isinstance(call.func.value, ast.Name) - and call.func.value.id == "self" - and call.func.attr == "get" - and call.args - and isinstance(call.args[0], ast.Constant) - and isinstance(call.args[0].value, str) - ): - param_name = call.args[0].value - for target in stmt.targets: - if isinstance(target, ast.Name): - m[target.id] = param_name - return m - - # --- visit calls to build call graph + rules --- - - def visit_Call(self, node: ast.Call): - # record method call edges: self.some_method(...) - if ( - isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "self" - and isinstance(node.func.attr, str) - ): - callee = node.func.attr - if self.current_method is not None: - # method call on self - self.call_graph[self.current_method].add(callee) - - # detect self.prohibit(, "") - if ( - isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "self" - and node.func.attr == "prohibit" - and len(node.args) >= 2 - ): - condition, msg_node = node.args[0], node.args[1] - if isinstance(msg_node, ast.Constant) and isinstance(msg_node.value, str): - params = sorted(self._extract_params(condition)) - rule = Rule( - method=self.current_method or "", - lineno=node.lineno, - params=params, - message=msg_node.value, - ) - self.rules.append(rule) - - self.generic_visit(node) - - def _extract_params(self, condition: ast.AST) -> Set[str]: - """ - Collect parameter names used in the condition via: - - local variables mapped from 
self.get(...) - - direct self.get('param_name', ...) calls - """ - params: Set[str] = set() - local_map = self.local_param_stack[-1] if self.local_param_stack else {} - - for node in ast.walk(condition): - # local names - if isinstance(node, ast.Name) and node.id in local_map: - params.add(local_map[node.id]) - - # direct self.get('param_name') - if isinstance(node, ast.Call): - if ( # pylint: disable=too-many-boolean-expressions - isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "self" - and node.func.attr == "get" - and node.args - and isinstance(node.args[0], ast.Constant) - and isinstance(node.args[0].value, str) - ): - params.add(node.args[0].value) - - return params - - -# --------------------------------------------------------------------------- -# Stage inference from validate_* roots and call graph -# --------------------------------------------------------------------------- - -STAGE_ROOTS: Dict[str, List[str]] = { - "common": ["validate_common"], - "simulation": ["validate_simulation"], - "pre_process": ["validate_pre_process"], - "post_process": ["validate_post_process"], -} - - -def compute_method_stages(call_graph: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - """ - For each stage (simulation/pre_process/post_process/common), starting from - validate_* roots, walk the call graph and record which methods belong to which stages. 
- """ - method_stages: Dict[str, Set[str]] = defaultdict(set) - - def dfs(start: str, stage: str): - stack = [start] - visited: Set[str] = set() - while stack: - m = stack.pop() - if m in visited: - continue - visited.add(m) - method_stages[m].add(stage) - for nxt in call_graph.get(m, ()): - if nxt not in visited: - stack.append(nxt) - - for stage, roots in STAGE_ROOTS.items(): - for root in roots: - dfs(root, stage) - - return method_stages - - -# --------------------------------------------------------------------------- -# Classification of messages for nicer grouping -# --------------------------------------------------------------------------- - -def classify_message(msg: str) -> str: - """ - Roughly classify rule messages for nicer grouping. - - Returns one of: "requirement", "incompatibility", "range", "other". - """ - text = msg.lower() - - if ( # pylint: disable=too-many-boolean-expressions - "not compatible" in text - or "does not support" in text - or "cannot be used" in text - or "must not" in text - or "is not supported" in text - or "incompatible" in text - or "untested" in text - ): - return "incompatibility" - - if ( # pylint: disable=too-many-boolean-expressions - "requires" in text - or "must be set if" in text - or "must be specified" in text - or "must be set with" in text - or "can only be enabled if" in text - or "must be set for" in text - ): - return "requirement" - - if ( # pylint: disable=too-many-boolean-expressions - "must be between" in text - or "must be positive" in text - or "must be non-negative" in text - or "must be greater than" in text - or "must be less than" in text - or "must be at least" in text - or "must be <=" in text - or "must be >=" in text - or "must be odd" in text - or "divisible by" in text - ): - return "range" - - return "other" - - -# Optional: nicer display names / categories (you can extend this) -FEATURE_META = { - "igr": {"title": "Iterative Generalized Riemann (IGR)", "category": "solver"}, - 
"bubbles_euler": {"title": "Euler–Euler Bubble Model", "category": "bubbles"}, - "bubbles_lagrange": {"title": "Euler–Lagrange Bubble Model", "category": "bubbles"}, - "qbmm": {"title": "Quadrature-Based Moment Method (QBMM)", "category": "bubbles"}, - "polydisperse": {"title": "Polydisperse Bubble Dynamics", "category": "bubbles"}, - "mhd": {"title": "Magnetohydrodynamics (MHD)", "category": "physics"}, - "alt_soundspeed": {"title": "Alternative Sound Speed", "category": "physics"}, - "surface_tension": {"title": "Surface Tension Model", "category": "physics"}, - "hypoelasticity": {"title": "Hypoelasticity", "category": "physics"}, - "hyperelasticity": {"title": "Hyperelasticity", "category": "physics"}, - "relax": {"title": "Phase Change (Relaxation)", "category": "physics"}, - "viscous": {"title": "Viscosity", "category": "physics"}, - "acoustic_source": {"title": "Acoustic Sources", "category": "physics"}, - "ib": {"title": "Immersed Boundaries", "category": "geometry"}, - "cyl_coord": {"title": "Cylindrical Coordinates", "category": "geometry"}, - "weno_order": {"title": "WENO Order", "category": "numerics"}, - "muscl_order": {"title": "MUSCL Order", "category": "numerics"}, - "riemann_solver": {"title": "Riemann Solver", "category": "numerics"}, - "model_eqns": {"title": "Model Equations", "category": "fundamentals"}, - "num_fluids": {"title": "Number of Fluids", "category": "fundamentals"}, -} - - -def feature_title(param: str) -> str: - meta = FEATURE_META.get(param) - if meta and "title" in meta: - return meta["title"] - return param +from mfc.params import CONSTRAINTS, DEPENDENCIES, get_value_label # noqa: E402 pylint: disable=wrong-import-position +from mfc.params.ast_analyzer import ( # noqa: E402 pylint: disable=wrong-import-position + Rule, classify_message, feature_title, + analyze_case_validator, +) # --------------------------------------------------------------------------- @@ -463,36 +188,24 @@ def summarize_case_params(params: Dict[str, Any]) -> 
Dict[str, Any]: def get_model_name(model_eqns: int | None) -> str: - """Get human-friendly model name""" - models = { - 1: "π-γ (Compressible Euler)", - 2: "5-Equation (Multiphase)", - 3: "6-Equation (Phase Change)", - 4: "4-Equation (Single Component)" - } - return models.get(model_eqns, "Not specified") + """Get human-friendly model name from schema.""" + if model_eqns is None: + return "Not specified" + return get_value_label("model_eqns", model_eqns) or "Not specified" def get_riemann_solver_name(solver: int | None) -> str: - """Get Riemann solver name""" - solvers = { - 1: "HLL", - 2: "HLLC", - 3: "Exact", - 4: "HLLD", - 5: "Lax-Friedrichs" - } - return solvers.get(solver, "Not specified") + """Get Riemann solver name from schema.""" + if solver is None: + return "Not specified" + return get_value_label("riemann_solver", solver) or "Not specified" def get_time_stepper_name(stepper: int | None) -> str: - """Get time stepper name""" - steppers = { - 1: "RK1 (Forward Euler)", - 2: "RK2", - 3: "RK3 (SSP)" - } - return steppers.get(stepper, "Not specified") + """Get time stepper name from schema.""" + if stepper is None: + return "Not specified" + return get_value_label("time_stepper", stepper) or "Not specified" def render_playbook_card(entry: PlaybookEntry, summary: Dict[str, Any]) -> str: # pylint: disable=too-many-branches,too-many-statements @@ -778,67 +491,120 @@ def render_markdown(rules: Iterable[Rule]) -> str: # pylint: disable=too-many-l lines.append("") - # 3. Model Equations + # 3. Model Equations (data-driven from schema) lines.append("## 🔢 Model Equations\n") lines.append("Choose your governing equations:\n") lines.append("") - lines.append("
") - lines.append("Model 1: π-γ (Compressible Euler)\n") - lines.append("- **Use for:** Single-fluid compressible flow") - lines.append("- **Value:** `model_eqns = 1`") - lines.append("- **Note:** Cannot use `num_fluids`, bubbles, or certain WENO variants") - lines.append("
\n") + def _format_model_requirements(val: int) -> str: + """Auto-generate requirements string from DEPENDENCIES['model_eqns']['when_value'].""" + me_dep = DEPENDENCIES.get("model_eqns", {}) + wv = me_dep.get("when_value", {}).get(val, {}) + if not wv: + return "" + parts = [] + if "requires" in wv: + parts.extend(f"Set `{r}`" for r in wv["requires"]) + if "requires_value" in wv: + for rv_param, rv_vals in wv["requires_value"].items(): + labeled = [f"`{v}` ({get_value_label(rv_param, v)})" for v in rv_vals] + parts.append(f"`{rv_param}` = {' or '.join(labeled)}") + return ", ".join(parts) + + # Curated editorial notes keyed by model_eqns value + _model_notes = { + 1: { + "use_for": "Single-fluid compressible flow", + "note": "Cannot use `num_fluids`, bubbles, or certain WENO variants", + }, + 2: { + "use_for": "Multiphase, bubbles, elastic materials, MHD", + "note": "Compatible with most physics models", + }, + 3: { + "use_for": "Phase change, cavitation", + "note": "Not compatible with bubbles or 3D cylindrical", + }, + 4: { + "use_for": "Single-component flows with bubbles", + }, + } - lines.append("
") - lines.append("Model 2: 5-Equation (Most versatile)\n") - lines.append("- **Use for:** Multiphase, bubbles, elastic materials, MHD") - lines.append("- **Value:** `model_eqns = 2`") - lines.append("- **Requirements:** Set `num_fluids`") - lines.append("- **Compatible with:** Most physics models") - lines.append("
\n") + # Auto-populate requirements from schema + for _val, _note in _model_notes.items(): + _auto_reqs = _format_model_requirements(_val) + if _auto_reqs: + _note["requirements"] = _auto_reqs - lines.append("
") - lines.append("Model 3: 6-Equation (Phase change)\n") - lines.append("- **Use for:** Phase change, cavitation") - lines.append("- **Value:** `model_eqns = 3`") - lines.append("- **Requirements:** `riemann_solver = 2` (HLLC), `avg_state = 2`, `wave_speeds = 1`") - lines.append("- **Note:** Not compatible with bubbles or 3D cylindrical") - lines.append("
\n") + model_constraint = CONSTRAINTS["model_eqns"] + for val in model_constraint["choices"]: + label = get_value_label("model_eqns", val) + notes = _model_notes.get(val, {}) + lines.append("
") + lines.append(f"Model {val}: {label}\n") + if "use_for" in notes: + lines.append(f"- **Use for:** {notes['use_for']}") + lines.append(f"- **Value:** `model_eqns = {val}`") + if "requirements" in notes: + lines.append(f"- **Requirements:** {notes['requirements']}") + if "note" in notes: + lines.append(f"- **Note:** {notes['note']}") + lines.append("
\n") - lines.append("
") - lines.append("Model 4: 4-Equation (Single component)\n") - lines.append("- **Use for:** Single-component flows with bubbles") - lines.append("- **Value:** `model_eqns = 4`") - lines.append("- **Requirements:** `num_fluids = 1`, set `rhoref` and `pref`") - lines.append("
\n") + # 4. Riemann Solvers (data-driven from schema) + # Curated editorial notes keyed by riemann_solver value + _solver_notes = { + 1: {"best_for": "MHD, elastic materials", "requirements": "—"}, + 2: {"best_for": "Bubbles, phase change, multiphase", "requirements": "`avg_state=2` for bubbles"}, + 3: {"best_for": "High accuracy (expensive)", "requirements": "—"}, + 4: {"best_for": "MHD (advanced)", "requirements": "MHD only, no relativity"}, + 5: {"best_for": "Robust fallback", "requirements": "Not with cylindrical+viscous"}, + } - # 4. Riemann Solvers (simplified) lines.append("## ⚙️ Riemann Solvers\n") lines.append("| Solver | `riemann_solver` | Best For | Requirements |") lines.append("|--------|-----------------|----------|-------------|") - lines.append("| **HLL** | `1` | MHD, elastic materials | — |") - lines.append("| **HLLC** | `2` | Bubbles, phase change, multiphase | `avg_state=2` for bubbles |") - lines.append("| **Exact** | `3` | High accuracy (expensive) | — |") - lines.append("| **HLLD** | `4` | MHD (advanced) | MHD only, no relativity |") - lines.append("| **Lax-Friedrichs** | `5` | Robust fallback | Not with cylindrical+viscous |") + + solver_constraint = CONSTRAINTS["riemann_solver"] + for val in solver_constraint["choices"]: + label = get_value_label("riemann_solver", val) + notes = _solver_notes.get(val, {}) + best = notes.get("best_for", "—") + reqs = notes.get("requirements", "—") + lines.append(f"| **{label}** | `{val}` | {best} | {reqs} |") + lines.append("") - # 5. Bubble Models (enhanced with collapsible) + # 5. Bubble Models (data-driven from schema dependencies + curated notes) if "bubbles_euler" in by_param or "bubbles_lagrange" in by_param: lines.append("## 💧 Bubble Models\n") lines.append("") + # Euler-Euler: inject schema dependency info (data-driven) lines.append("
") lines.append("Euler-Euler (`bubbles_euler`)\n") lines.append("**Requirements:**") - lines.append("- `model_eqns = 2` or `4`") - lines.append("- `riemann_solver = 2` (HLLC)") - lines.append("- `avg_state = 2`") - lines.append("- Set `nb` (number of bins) ≥ 1\n") + be_dep = DEPENDENCIES.get("bubbles_euler", {}) + be_when_true = be_dep.get("when_true", {}) + be_rv = be_when_true.get("requires_value", {}) + for rv_param, rv_vals in be_rv.items(): + labeled = [f"`{v}` ({get_value_label(rv_param, v)})" for v in rv_vals] + lines.append(f"- `{rv_param}` = {' or '.join(labeled)}") + be_recs = be_when_true.get("recommends", []) + if be_recs: + lines.append(f"- Recommended to also set: {', '.join(f'`{r}`' for r in be_recs)}") + lines.append("") lines.append("**Extensions:**") - lines.append("- `polydisperse = T`: Multiple bubble sizes (requires odd `nb > 1`)") - lines.append("- `qbmm = T`: Quadrature method (requires `nnode = 4`)") + # Inject polydisperse dependency + pd_dep = DEPENDENCIES.get("polydisperse", {}) + pd_reqs = pd_dep.get("when_true", {}).get("requires", []) + pd_req_str = f" (requires {', '.join(f'`{r}`' for r in pd_reqs)})" if pd_reqs else "" + lines.append(f"- `polydisperse = T`: Multiple bubble sizes{pd_req_str}, odd `nb > 1`") + # Inject qbmm dependency + qb_dep = DEPENDENCIES.get("qbmm", {}) + qb_recs = qb_dep.get("when_true", {}).get("recommends", []) + qb_rec_str = f" (recommends {', '.join(f'`{r}`' for r in qb_recs)})" if qb_recs else "" + lines.append(f"- `qbmm = T`: Quadrature method{qb_rec_str}, requires `nnode = 4`") lines.append("- `adv_n = T`: Number density advection (requires `num_fluids = 1`)") lines.append("
\n") @@ -851,37 +617,30 @@ def render_markdown(rules: Iterable[Rule]) -> str: # pylint: disable=too-many-l lines.append("**Note:** Tracks individual bubbles") lines.append("\n") - # 6. Condensed Parameter Reference + # 6. Condensed Parameter Reference (auto-collected from schema) lines.append("## 📖 Quick Parameter Reference\n") lines.append("Key parameters and their constraints:\n") - # Highlight only the most important parameters in collapsible sections - important_params = { - "MHD": "mhd", - "Surface Tension": "surface_tension", - "Viscosity": "viscous", - "Number of Fluids": "num_fluids", - "Cylindrical Coordinates": "cyl_coord", - "Immersed Boundaries": "ib", - } + # Auto-collect all params that have CONSTRAINTS or DEPENDENCIES entries + quick_ref_params = sorted(set(CONSTRAINTS.keys()) | set(DEPENDENCIES.keys())) - for title, param in important_params.items(): - if param not in by_param: - continue + for param in quick_ref_params: + title = feature_title(param) - rules_for_param = by_param[param] + # Gather schema info + constraint = CONSTRAINTS.get(param, {}) + dep = DEPENDENCIES.get(param, {}) - # Get key info + # Gather AST-extracted rules + rules_for_param = by_param.get(param, []) requirements = [] incompatibilities = [] ranges = [] for rule in rules_for_param: msg = rule.message - # Skip IGR-related messages if "IGR" in msg: continue - kind = classify_message(msg) if kind == "requirement": requirements.append(msg) @@ -890,12 +649,65 @@ def render_markdown(rules: Iterable[Rule]) -> str: # pylint: disable=too-many-l elif kind == "range": ranges.append(msg) - if not (requirements or incompatibilities or ranges): + # Build schema constraint summary + schema_parts = [] + if "choices" in constraint: + labels = constraint.get("value_labels", {}) + if labels: + items = [f"`{v}` = {labels[v]}" for v in constraint["choices"] if v in labels] + schema_parts.append("Choices: " + ", ".join(items)) + else: + schema_parts.append(f"Choices: {constraint['choices']}") + 
if "min" in constraint: + schema_parts.append(f"Min: {constraint['min']}") + if "max" in constraint: + schema_parts.append(f"Max: {constraint['max']}") + + # Build dependency summary + dep_parts = [] + + def _render_cond_parts(trigger_str, cond_dict): + """Render a condition dict into dep_parts entries.""" + if "requires" in cond_dict: + dep_parts.append(f"When {trigger_str}, requires: {', '.join(f'`{r}`' for r in cond_dict['requires'])}") + if "requires_value" in cond_dict: + rv_items = [] + for rv_p, rv_vs in cond_dict["requires_value"].items(): + labeled = [f"`{v}` ({get_value_label(rv_p, v)})" for v in rv_vs] + rv_items.append(f"`{rv_p}` = {' or '.join(labeled)}") + dep_parts.append(f"When {trigger_str}, requires {', '.join(rv_items)}") + if "recommends" in cond_dict: + dep_parts.append(f"When {trigger_str}, recommends: {', '.join(f'`{r}`' for r in cond_dict['recommends'])}") + + for cond_key in ["when_true", "when_set"]: + cond = dep.get(cond_key, {}) + if cond: + trigger = "enabled" if cond_key == "when_true" else "set" + _render_cond_parts(trigger, cond) + + if "when_value" in dep: + for wv_val, wv_cond in dep["when_value"].items(): + _render_cond_parts(f"= {wv_val}", wv_cond) + + # Skip if nothing to show + if not (schema_parts or dep_parts or requirements or incompatibilities or ranges): continue lines.append(f"\n
") lines.append(f"{title} (`{param}`)\n") + if schema_parts: + lines.append("**Schema constraints:**") + for sp in schema_parts: + lines.append(f"- {sp}") + lines.append("") + + if dep_parts: + lines.append("**Dependencies:**") + for dp in dep_parts: + lines.append(f"- {dp}") + lines.append("") + if requirements: lines.append("**Requirements:**") for req in requirements[:3]: @@ -929,20 +741,8 @@ def render_markdown(rules: Iterable[Rule]) -> str: # pylint: disable=too-many-l # --------------------------------------------------------------------------- def main() -> None: - src = CASE_VALIDATOR_PATH.read_text(encoding="utf-8") - tree = ast.parse(src, filename=str(CASE_VALIDATOR_PATH)) - - analyzer = CaseValidatorAnalyzer() - analyzer.visit(tree) - - # Infer stages per method from call graph - method_stages = compute_method_stages(analyzer.call_graph) - - # Attach stages to rules - for r in analyzer.rules: - r.stages = method_stages.get(r.method, set()) - - md = render_markdown(analyzer.rules) + analysis = analyze_case_validator(CASE_VALIDATOR_PATH) + md = render_markdown(analysis["rules"]) print(md) diff --git a/toolchain/mfc/params/__init__.py b/toolchain/mfc/params/__init__.py index 079f04768b..57e1912276 100644 --- a/toolchain/mfc/params/__init__.py +++ b/toolchain/mfc/params/__init__.py @@ -25,5 +25,9 @@ # IMPORTANT: This import populates REGISTRY with all parameter definitions # and freezes it. It must come after REGISTRY is imported and must not be removed. from . 
import definitions # noqa: F401 pylint: disable=unused-import +from .definitions import CONSTRAINTS, DEPENDENCIES, get_value_label -__all__ = ['REGISTRY', 'RegistryFrozenError', 'ParamDef', 'ParamType'] +__all__ = [ + 'REGISTRY', 'RegistryFrozenError', 'ParamDef', 'ParamType', + 'CONSTRAINTS', 'DEPENDENCIES', 'get_value_label', +] diff --git a/toolchain/mfc/params/ast_analyzer.py b/toolchain/mfc/params/ast_analyzer.py new file mode 100644 index 0000000000..692297bfc7 --- /dev/null +++ b/toolchain/mfc/params/ast_analyzer.py @@ -0,0 +1,767 @@ +""" +Shared AST analyzer for case_validator.py. + +Extracts all `self.prohibit(...)` rules from CaseValidator, determines +which parameter "triggers" each rule, and provides convenience functions +for both doc generators (parameters.md and case_constraints.md). +""" + +from __future__ import annotations + +import ast +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Set +from collections import defaultdict + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class Rule: + method: str # e.g. "check_igr_simulation" + lineno: int # line number of the prohibit call + params: List[str] # case parameter names used in condition + message: str # user-friendly error message + stages: Set[str] = field(default_factory=set) # e.g. {"simulation", "pre_process"} + trigger: Optional[str] = None # param that "owns" this rule + + +# --------------------------------------------------------------------------- +# F-string message extraction +# --------------------------------------------------------------------------- + +def _extract_message(node: ast.AST) -> Optional[str]: + """ + Extract the message string from a prohibit() call's second argument. + + Handles both plain strings (ast.Constant) and f-strings (ast.JoinedStr). 
+ For f-strings, FormattedValue expressions are replaced with their + unparsed source representation, giving a readable approximation. + """ + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + + if isinstance(node, ast.JoinedStr): + parts = [] + for value in node.values: + if isinstance(value, ast.Constant): + parts.append(str(value.value)) + elif isinstance(value, ast.FormattedValue): + # Unparse the expression to get a readable approximation + try: + parts.append(ast.unparse(value.value)) + except Exception: # pylint: disable=broad-except + parts.append("?") + else: + parts.append("?") + return "".join(parts) + + return None + + +def _resolve_fstring(node: ast.JoinedStr, subs: Dict[str, str]) -> Optional[str]: + """Resolve a JoinedStr (f-string) by substituting known loop variables.""" + parts: List[str] = [] + for v in node.values: + if isinstance(v, ast.Constant): + parts.append(str(v.value)) + elif isinstance(v, ast.FormattedValue): + if isinstance(v.value, ast.Name) and v.value.id in subs: + parts.append(subs[v.value.id]) + else: + try: + parts.append(ast.unparse(v.value)) + except Exception: # pylint: disable=broad-except + parts.append("?") + else: + parts.append("?") + return "".join(parts) + + +def _resolve_message(msg_node: ast.AST, subs: Dict[str, str]) -> Optional[str]: + """Resolve a prohibit message, substituting loop variables in f-strings.""" + if isinstance(msg_node, ast.Constant) and isinstance(msg_node.value, str): + return msg_node.value + if isinstance(msg_node, ast.JoinedStr): + return _resolve_fstring(msg_node, subs) + return None + + +def _is_self_get(call: ast.Call) -> bool: + """Check if a Call node is self.get(...).""" + return (isinstance(call.func, ast.Attribute) + and isinstance(call.func.value, ast.Name) + and call.func.value.id == "self" + and call.func.attr == "get" + and bool(call.args)) + + +# --------------------------------------------------------------------------- +# AST analysis: 
methods, call graph, rules +# --------------------------------------------------------------------------- + +class CaseValidatorAnalyzer(ast.NodeVisitor): # pylint: disable=too-many-instance-attributes + """ + Analyzes the CaseValidator class: + + - collects all methods + - builds a call graph between methods + - extracts all self.prohibit(...) rules + """ + + def __init__(self): + super().__init__() + self.in_case_validator = False + self.current_method: str | None = None + + self.methods: Dict[str, ast.FunctionDef] = {} + self.call_graph: Dict[str, Set[str]] = defaultdict(set) + self.rules: List[Rule] = [] + + # Stack of {local_name -> param_name} maps, one per method + self.local_param_stack: List[Dict[str, str]] = [] + + # Stack of {alias_name -> [source_param, ...]} maps (parallel to local_param_stack) + self.alias_map_stack: List[Dict[str, List[str]]] = [] + + # {method_name -> trigger_param} from guard detection + self._method_guards: Dict[str, str] = {} + + # Line numbers of prohibit calls handled by loop expansion (skip in visit_Call) + self._expanded_prohibit_lines: Set[int] = set() + + # --- top-level entrypoint --- + + def visit_ClassDef(self, node: ast.ClassDef): + if node.name == "CaseValidator": + self.in_case_validator = True + # collect methods + for item in node.body: + if isinstance(item, ast.FunctionDef): + self.methods[item.name] = item + # now analyze all methods + for method in self.methods.values(): + self._analyze_method(method) + self.in_case_validator = False + else: + self.generic_visit(node) + + # --- per-method analysis --- + + def _analyze_method(self, func: ast.FunctionDef): + """Analyze a single method: local param mapping, call graph, rules.""" + self.current_method = func.name + local_param_map = self._build_local_param_map(func) + alias_map = self._build_alias_map(func, local_param_map) + self.local_param_stack.append(local_param_map) + self.alias_map_stack.append(alias_map) + + # Detect method guard pattern + guard = 
_extract_method_guard(func, local_param_map) + if guard: + self._method_guards[func.name] = guard + + # Expand literal-list for-loops before generic_visit + self._expand_literal_loops(func, local_param_map) + + self.generic_visit(func) + + # Enrich rules with params from if-guard conditions + self._enrich_rules_with_if_guards(func, local_param_map, alias_map) + + self.alias_map_stack.pop() + self.local_param_stack.pop() + self.current_method = None + + def _enrich_rules_with_if_guards(self, func: ast.FunctionDef, + local_param_map: Dict[str, str], + alias_map: Dict[str, List[str]]): + """ + After rules are extracted, walk the function body for ast.If nodes. + For each if-block, extract guard params from the test condition and add + them to every rule whose lineno falls within the block's line range. + """ + for node in ast.walk(func): # pylint: disable=too-many-nested-blocks + if not isinstance(node, ast.If): + continue + # Extract params from the if-test condition + guard_params = _extract_test_params(node.test, local_param_map, alias_map) + if not guard_params: + continue + # Determine line ranges for body and orelse + ranges = [] + if node.body: + body_start = node.body[0].lineno + body_end = node.body[-1].end_lineno or node.body[-1].lineno + ranges.append((body_start, body_end)) + if node.orelse: + else_start = node.orelse[0].lineno + else_end = node.orelse[-1].end_lineno or node.orelse[-1].lineno + ranges.append((else_start, else_end)) + # Enrich matching rules + for rule in self.rules: + if rule.method != func.name: + continue + for rng_start, rng_end in ranges: + if rng_start <= rule.lineno <= rng_end: + for gp in guard_params: + if gp not in rule.params: + rule.params.append(gp) + break + + def _build_local_param_map(self, func: ast.FunctionDef) -> Dict[str, str]: # pylint: disable=too-many-nested-blocks + """ + Look for assignments like: + igr = self.get('igr', 'F') == 'T' + model_eqns = self.get('model_eqns') + and record local_name -> 'param_name'. 
+ + Uses ast.walk to find assignments at any nesting depth (inside if/for/with blocks). + """ + m: Dict[str, str] = {} + for node in ast.walk(func): # pylint: disable=too-many-nested-blocks + if isinstance(node, ast.Assign): + # Handle both direct calls and comparisons + value = node.value + # Unwrap comparisons like "self.get('igr', 'F') == 'T'" + if isinstance(value, ast.Compare): + value = value.left + + if isinstance(value, ast.Call): + call = value + if ( # pylint: disable=too-many-boolean-expressions + isinstance(call.func, ast.Attribute) + and isinstance(call.func.value, ast.Name) + and call.func.value.id == "self" + and call.func.attr == "get" + and call.args + and isinstance(call.args[0], ast.Constant) + and isinstance(call.args[0].value, str) + ): + param_name = call.args[0].value + for target in node.targets: + if isinstance(target, ast.Name): + m[target.id] = param_name + return m + + @staticmethod + def _build_alias_map(func: ast.FunctionDef, + local_param_map: Dict[str, str]) -> Dict[str, List[str]]: + """ + Detect boolean alias assignments like: + variable_dt = cfl_dt or cfl_adap_dt # BoolOp(Or) + has_output = rho_wrt or E_wrt or ... # BoolOp(Or) with many operands + skip_check = cyl_coord and dir in [...] # BoolOp(And) + + Returns {alias_name -> [source_param, ...]}. 
+ """ + alias_map: Dict[str, List[str]] = {} + for node in ast.walk(func): + if not isinstance(node, ast.Assign): + continue + if not isinstance(node.value, ast.BoolOp): + continue + # Collect source params from BoolOp operands + sources: List[str] = [] + for operand in node.value.values: + if isinstance(operand, ast.Name) and operand.id in local_param_map: + sources.append(local_param_map[operand.id]) + if not sources: + continue + for target in node.targets: + if isinstance(target, ast.Name): + alias_map[target.id] = sources + return alias_map + + # --- literal-list for-loop expansion --- + + def _expand_literal_loops(self, func: ast.FunctionDef, local_param_map: Dict[str, str]): + """Expand `for var in [x, y, z]:` loops into concrete Rules.""" + self._expand_loop_stmts(func.body, func.name, local_param_map, {}) + + def _expand_loop_stmts(self, stmts: list, method_name: str, + parent_map: Dict[str, str], subs: Dict[str, str]): + """Recursively find literal-list for-loops and create expanded Rules.""" + for stmt in stmts: + if (isinstance(stmt, ast.For) + and isinstance(stmt.target, ast.Name) + and isinstance(stmt.iter, ast.List) + and all(isinstance(e, ast.Constant) for e in stmt.iter.elts)): + var = stmt.target.id + for elt in stmt.iter.elts: + new_subs = {**subs, var: str(elt.value)} + loop_map = self._resolve_loop_gets(stmt.body, new_subs) + merged = {**parent_map, **loop_map} + # Detect loop-body guard: `if not : continue` + loop_guard = self._detect_loop_guard(stmt.body, merged) + # Recurse for nested literal-list loops + self._expand_loop_stmts(stmt.body, method_name, merged, new_subs) + # Create Rules for prohibit calls at this level + self._create_loop_rules(stmt.body, method_name, merged, new_subs, loop_guard) + elif subs: + # Inside an expanded loop: recurse into if/else blocks + if isinstance(stmt, ast.If): + self._expand_loop_stmts(stmt.body, method_name, parent_map, subs) + if stmt.orelse: + self._expand_loop_stmts(stmt.orelse, method_name, 
parent_map, subs) + + @staticmethod + def _detect_loop_guard(stmts: list, local_map: Dict[str, str]) -> Optional[str]: + """Detect `if not : continue` guard pattern in a loop body.""" + for stmt in stmts: + if not isinstance(stmt, ast.If): + continue + test = stmt.test + if (isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not) + and isinstance(test.operand, ast.Name)): + if (len(stmt.body) == 1 and isinstance(stmt.body[0], ast.Continue)): + var_name = test.operand.id + if var_name in local_map: + return local_map[var_name] + return None + + @staticmethod + def _resolve_loop_gets(stmts: list, subs: Dict[str, str]) -> Dict[str, str]: + """Resolve self.get() assignments in loop body, substituting f-string loop vars.""" + m: Dict[str, str] = {} + for stmt in stmts: + for node in ast.walk(stmt): + if not isinstance(node, ast.Assign): + continue + value = node.value + if isinstance(value, ast.Compare): + value = value.left + if not isinstance(value, ast.Call) or not _is_self_get(value): + continue + arg = value.args[0] + if isinstance(arg, ast.JoinedStr): + resolved = _resolve_fstring(arg, subs) + if resolved is None: + continue + param_name = resolved + elif isinstance(arg, ast.Constant) and isinstance(arg.value, str): + param_name = arg.value + else: + continue + for target in node.targets: + if isinstance(target, ast.Name): + m[target.id] = param_name + return m + + def _create_loop_rules(self, stmts: list, method_name: str, # pylint: disable=too-many-arguments,too-many-positional-arguments + local_map: Dict[str, str], subs: Dict[str, str], + loop_guard: Optional[str] = None): + """Create Rules for self.prohibit() calls found in loop body statements.""" + for stmt in stmts: + # Skip nested literal-list for-loops (handled by recursion) + if (isinstance(stmt, ast.For) + and isinstance(stmt.target, ast.Name) + and isinstance(stmt.iter, ast.List) + and all(isinstance(e, ast.Constant) for e in stmt.iter.elts)): + continue + for node in ast.walk(stmt): + if not 
isinstance(node, ast.Call): + continue + if not (isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + and node.func.attr == "prohibit" + and len(node.args) >= 2): + continue + condition, msg_node = node.args[0], node.args[1] + msg = _resolve_message(msg_node, subs) + if msg is None: + continue + param_set = self._extract_params_with_subs(condition, local_map, subs) + # Use loop-body guard as trigger and include it in params + if loop_guard: + trigger = loop_guard + param_set.add(loop_guard) + else: + trigger = self._determine_trigger( + sorted(param_set), condition, local_map) + params = sorted(param_set) + rule = Rule( + method=method_name, + lineno=node.lineno, + params=params, + message=msg, + trigger=trigger, + ) + self.rules.append(rule) + self._expanded_prohibit_lines.add(node.lineno) + + def _extract_params_with_subs(self, condition: ast.AST, + local_map: Dict[str, str], + subs: Dict[str, str]) -> Set[str]: + """Like _extract_params but also resolves JoinedStr self.get() args.""" + params: Set[str] = set() + for node in ast.walk(condition): + if isinstance(node, ast.Name) and node.id in local_map: + params.add(local_map[node.id]) + if isinstance(node, ast.Call) and _is_self_get(node): + arg = node.args[0] + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): + params.add(arg.value) + elif isinstance(arg, ast.JoinedStr): + resolved = _resolve_fstring(arg, subs) + if resolved: + params.add(resolved) + return params + + # --- visit calls to build call graph + rules --- + + def visit_Call(self, node: ast.Call): + # record method call edges: self.some_method(...) 
+ if ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + and isinstance(node.func.attr, str) + ): + callee = node.func.attr + if self.current_method is not None: + # method call on self + self.call_graph[self.current_method].add(callee) + + # detect self.prohibit(, "") + # Skip prohibit calls already handled by loop expansion + if ( # pylint: disable=too-many-boolean-expressions + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + and node.func.attr == "prohibit" + and len(node.args) >= 2 + and node.lineno not in self._expanded_prohibit_lines + ): + condition, msg_node = node.args[0], node.args[1] + msg = _extract_message(msg_node) + if msg is not None: + local_map = self.local_param_stack[-1] if self.local_param_stack else {} + params = sorted(self._extract_params(condition)) + trigger = self._determine_trigger(params, condition, local_map) + rule = Rule( + method=self.current_method or "", + lineno=node.lineno, + params=params, + message=msg, + trigger=trigger, + ) + self.rules.append(rule) + + self.generic_visit(node) + + def _determine_trigger(self, _params: List[str], condition: ast.AST, + local_map: Dict[str, str]) -> Optional[str]: + """Determine trigger param: method guard first, then condition fallback.""" + # 1. Method guard (high confidence) + if self.current_method and self.current_method in self._method_guards: + return self._method_guards[self.current_method] + + # 2. Condition first-param fallback (with alias resolution) + alias_map = self.alias_map_stack[-1] if self.alias_map_stack else {} + return _extract_trigger_from_condition(condition, local_map, alias_map) + + def _extract_params(self, condition: ast.AST) -> Set[str]: + """ + Collect parameter names used in the condition via: + - local variables mapped from self.get(...) 
+ - boolean aliases (variable_dt → [cfl_dt, cfl_adap_dt]) + - direct self.get('param_name', ...) calls + """ + params: Set[str] = set() + local_map = self.local_param_stack[-1] if self.local_param_stack else {} + alias_map = self.alias_map_stack[-1] if self.alias_map_stack else {} + + for node in ast.walk(condition): + # local names + if isinstance(node, ast.Name): + if node.id in local_map: + params.add(local_map[node.id]) + elif node.id in alias_map: + params.update(alias_map[node.id]) + + # direct self.get('param_name') + if isinstance(node, ast.Call): + if ( # pylint: disable=too-many-boolean-expressions + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + and node.func.attr == "get" + and node.args + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): + params.add(node.args[0].value) + + return params + + +# --------------------------------------------------------------------------- +# Trigger detection helpers +# --------------------------------------------------------------------------- + +def _extract_method_guard(func: ast.FunctionDef, local_param_map: Dict[str, str]) -> Optional[str]: + """ + Detect early-return guard patterns like: + if not bubbles_euler: + return + The guarded variable's param is the trigger for all rules in that method. + """ + for stmt in func.body: # pylint: disable=too-many-nested-blocks + if not isinstance(stmt, ast.If): + continue + + # Check for "if not : return" pattern + test = stmt.test + if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not): + if isinstance(test.operand, ast.Name): + var_name = test.operand.id + # Check body is just "return" + if (len(stmt.body) == 1 + and isinstance(stmt.body[0], ast.Return) + and stmt.body[0].value is None): + if var_name in local_param_map: + return local_param_map[var_name] + + # Check for "if != : return" pattern + # e.g. 
"if recon_type != 1: return" or "if model_eqns != 3: return" + if isinstance(test, ast.Compare) and len(test.ops) == 1: + if isinstance(test.ops[0], ast.NotEq): + if isinstance(test.left, ast.Name): + var_name = test.left.id + if (len(stmt.body) == 1 + and isinstance(stmt.body[0], ast.Return) + and stmt.body[0].value is None): + if var_name in local_param_map: + return local_param_map[var_name] + + return None + + +def _extract_test_params(test: ast.AST, local_param_map: Dict[str, str], + alias_map: Dict[str, List[str]]) -> Set[str]: + """Extract parameter names from an if-test condition, resolving aliases.""" + params: Set[str] = set() + for node in ast.walk(test): + if isinstance(node, ast.Name): + if node.id in local_param_map: + params.add(local_param_map[node.id]) + elif node.id in alias_map: + params.update(alias_map[node.id]) + if isinstance(node, ast.Call) and _is_self_get(node): + arg = node.args[0] + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): + params.add(arg.value) + return params + + +def _extract_trigger_from_condition(condition: ast.AST, local_param_map: Dict[str, str], + alias_map: Optional[Dict[str, List[str]]] = None) -> Optional[str]: + """ + Fallback trigger detection: walk the condition AST left-to-right, + return the first parameter name found. Resolves aliases to their first source param. 
+ """ + if alias_map is None: + alias_map = {} + # Walk in source order (left-to-right in boolean expressions) + for node in ast.walk(condition): + if isinstance(node, ast.Name): + if node.id in local_param_map: + return local_param_map[node.id] + if node.id in alias_map and alias_map[node.id]: + return alias_map[node.id][0] + if isinstance(node, ast.Call): + if ( # pylint: disable=too-many-boolean-expressions + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + and node.func.attr == "get" + and node.args + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): + return node.args[0].value + return None + + +# --------------------------------------------------------------------------- +# Stage inference from validate_* roots and call graph +# --------------------------------------------------------------------------- + +STAGE_ROOTS: Dict[str, List[str]] = { + "common": ["validate_common"], + "simulation": ["validate_simulation"], + "pre_process": ["validate_pre_process"], + "post_process": ["validate_post_process"], +} + + +def compute_method_stages(call_graph: Dict[str, Set[str]]) -> Dict[str, Set[str]]: + """ + For each stage (simulation/pre_process/post_process/common), starting from + validate_* roots, walk the call graph and record which methods belong to which stages. 
+ """ + method_stages: Dict[str, Set[str]] = defaultdict(set) + + def dfs(start: str, stage: str): + stack = [start] + visited: Set[str] = set() + while stack: + m = stack.pop() + if m in visited: + continue + visited.add(m) + method_stages[m].add(stage) + for nxt in call_graph.get(m, ()): + if nxt not in visited: + stack.append(nxt) + + for stage, roots in STAGE_ROOTS.items(): + for root in roots: + dfs(root, stage) + + return method_stages + + +# --------------------------------------------------------------------------- +# Classification of messages for nicer grouping +# --------------------------------------------------------------------------- + +def classify_message(msg: str) -> str: + """ + Roughly classify rule messages for nicer grouping. + + Returns one of: "requirement", "incompatibility", "range", "other". + """ + text = msg.lower() + + if ( # pylint: disable=too-many-boolean-expressions + "not compatible" in text + or "does not support" in text + or "cannot be used" in text + or "must not" in text + or "is not supported" in text + or "incompatible" in text + or "untested" in text + or "not available" in text + or "is not compatible" in text + or "activate only one" in text + ): + return "incompatibility" + + if ( # pylint: disable=too-many-boolean-expressions + "requires" in text + or "must be set if" in text + or "must be specified" in text + or "must be set with" in text + or "can only be enabled if" in text + or "must be set for" in text + or "must be set when" in text + ): + return "requirement" + + if ( # pylint: disable=too-many-boolean-expressions + "must be between" in text + or "must be positive" in text + or "must be non-negative" in text + or "must be greater than" in text + or "must be less than" in text + or "must be at least" in text + or "must be <=" in text + or "must be >=" in text + or "must be odd" in text + or "divisible by" in text + or re.search(r"must be 1\b", text) is not None + or "must be 'T' or 'F'" in text + ): + return 
"range" + + return "other" + + +# Optional: nicer display names / categories (you can extend this) +FEATURE_META = { + "igr": {"title": "Iterative Generalized Riemann (IGR)", "category": "solver"}, + "bubbles_euler": {"title": "Euler-Euler Bubble Model", "category": "bubbles"}, + "bubbles_lagrange": {"title": "Euler-Lagrange Bubble Model", "category": "bubbles"}, + "qbmm": {"title": "Quadrature-Based Moment Method (QBMM)", "category": "bubbles"}, + "polydisperse": {"title": "Polydisperse Bubble Dynamics", "category": "bubbles"}, + "mhd": {"title": "Magnetohydrodynamics (MHD)", "category": "physics"}, + "alt_soundspeed": {"title": "Alternative Sound Speed", "category": "physics"}, + "surface_tension": {"title": "Surface Tension Model", "category": "physics"}, + "hypoelasticity": {"title": "Hypoelasticity", "category": "physics"}, + "hyperelasticity": {"title": "Hyperelasticity", "category": "physics"}, + "relax": {"title": "Phase Change (Relaxation)", "category": "physics"}, + "viscous": {"title": "Viscosity", "category": "physics"}, + "acoustic_source": {"title": "Acoustic Sources", "category": "physics"}, + "ib": {"title": "Immersed Boundaries", "category": "geometry"}, + "cyl_coord": {"title": "Cylindrical Coordinates", "category": "geometry"}, + "weno_order": {"title": "WENO Order", "category": "numerics"}, + "muscl_order": {"title": "MUSCL Order", "category": "numerics"}, + "riemann_solver": {"title": "Riemann Solver", "category": "numerics"}, + "model_eqns": {"title": "Model Equations", "category": "fundamentals"}, + "num_fluids": {"title": "Number of Fluids", "category": "fundamentals"}, +} + + +def feature_title(param: str) -> str: + meta = FEATURE_META.get(param) + if meta and "title" in meta: + return meta["title"] + return param + + +# --------------------------------------------------------------------------- +# Convenience: full analysis pipeline +# --------------------------------------------------------------------------- + +_DEFAULT_VALIDATOR_PATH = 
Path(__file__).resolve().parent.parent / "case_validator.py" + + +def analyze_case_validator(path: Optional[Path] = None) -> Dict: + """ + Parse case_validator.py and return extracted rules indexed multiple ways. + + Returns a dict with: + rules: List[Rule] - all rules + by_trigger: Dict[str, List[Rule]] - rules indexed by trigger param + by_param: Dict[str, List[Rule]] - rules indexed by all mentioned params + """ + if path is None: + path = _DEFAULT_VALIDATOR_PATH + + src = path.read_text(encoding="utf-8") + tree = ast.parse(src, filename=str(path)) + + analyzer = CaseValidatorAnalyzer() + analyzer.visit(tree) + + # Infer stages per method from call graph + method_stages = compute_method_stages(analyzer.call_graph) + + # Attach stages to rules + for r in analyzer.rules: + r.stages = method_stages.get(r.method, set()) + + # Build indices + by_trigger: Dict[str, List[Rule]] = defaultdict(list) + by_param: Dict[str, List[Rule]] = defaultdict(list) + + for r in analyzer.rules: + if r.trigger: + by_trigger[r.trigger].append(r) + for p in r.params: + by_param[p].append(r) + + return { + "rules": analyzer.rules, + "by_trigger": dict(by_trigger), + "by_param": dict(by_param), + "call_graph": analyzer.call_graph, + "methods": analyzer.methods, + } diff --git a/toolchain/mfc/params/definitions.py b/toolchain/mfc/params/definitions.py index f3807dd451..4ef006a304 100644 --- a/toolchain/mfc/params/definitions.py +++ b/toolchain/mfc/params/definitions.py @@ -3,7 +3,7 @@ Single file containing all ~3,300 parameter definitions using loops. This replaces the definitions/ directory. 
-""" +""" # pylint: disable=too-many-lines import re from typing import Dict, Any @@ -163,7 +163,7 @@ "t_step_print": "Print interval (steps)", "t_stop": "Stop time", "t_save": "Save interval (time)", - "time_stepper": "Time integration scheme (1=Euler, 2=RK2, 3=RK3)", + "time_stepper": "Time integration scheme", "cfl_target": "Target CFL number", "cfl_max": "Maximum CFL number", "cfl_adap_dt": "Enable adaptive CFL time stepping", @@ -174,7 +174,7 @@ "adap_dt_max_iters": "Max iterations for adaptive dt", "t_tol": "Time tolerance", # Model - "model_eqns": "Model equations (1=gamma, 2=5-eq, 3=6-eq, 4=4-eq)", + "model_eqns": "Model equations", "num_fluids": "Number of fluids", "num_patches": "Number of IC patches", "mpp_lim": "Mixture pressure positivity limiter", @@ -186,9 +186,9 @@ "teno": "Enable TENO", "mp_weno": "Enable monotonicity-preserving WENO", # Riemann - "riemann_solver": "Riemann solver (1=HLL, 2=HLLC, 3=exact)", + "riemann_solver": "Riemann solver", "wave_speeds": "Wave speed estimate method", - "avg_state": "Average state (1=Roe, 2=arithmetic)", + "avg_state": "Average state", # Physics toggles "viscous": "Enable viscous effects", "mhd": "Enable magnetohydrodynamics", @@ -229,7 +229,7 @@ "old_ic": "Load initial conditions from previous", "t_step_old": "Time step to restart from", "fd_order": "Finite difference order", - "recon_type": "Reconstruction type (1=WENO, 2=MUSCL)", + "recon_type": "Reconstruction type", "muscl_order": "MUSCL reconstruction order", "muscl_lim": "MUSCL limiter type", "low_Mach": "Low Mach number correction", @@ -237,8 +237,8 @@ "Ca": "Cavitation number", "Web": "Weber number", "Re_inv": "Inverse Reynolds number", - "format": "Output format (1=Silo, 2=binary)", - "precision": "Output precision (1=single, 2=double)", + "format": "Output format", + "precision": "Output precision", # Body forces "bf_x": "Enable body force in x", "bf_y": "Enable body force in y", @@ -346,14 +346,120 @@ def _auto_describe(name: str) -> str: } +# 
============================================================================= +# Data-driven Annotations for Doc Generation +# ============================================================================= +# These dicts are the single source of truth for parameter hints in the docs. +# To annotate a new param, add an entry here instead of editing docs_gen.py. + +HINTS = { + "bc": { + "grcbc_in": "Enables GRCBC subsonic inflow (bc type -7)", + "grcbc_out": "Enables GRCBC subsonic outflow (bc type -8)", + "grcbc_vel_out": "GRCBC velocity outlet (requires `grcbc_out`)", + "vel_in": "Inlet velocity component (used with `grcbc_in`)", + "vel_out": "Outlet velocity component (used with `grcbc_vel_out`)", + "pres_in": "Inlet pressure (used with `grcbc_in`)", + "pres_out": "Outlet pressure (used with `grcbc_out`)", + "alpha_rho_in": "Inlet partial density per fluid (used with `grcbc_in`)", + "alpha_in": "Inlet volume fraction per fluid (used with `grcbc_in`)", + "vb1": "Boundary velocity component 1 at domain begin", + "vb2": "Boundary velocity component 2 at domain begin", + "vb3": "Boundary velocity component 3 at domain begin", + "ve1": "Boundary velocity component 1 at domain end", + "ve2": "Boundary velocity component 2 at domain end", + "ve3": "Boundary velocity component 3 at domain end", + }, + "patch_bc": { + "geometry": "Patch shape: 1=line, 2=circle, 3=rectangle", + "type": "BC type applied within patch region", + "dir": "Patch normal direction (1=x, 2=y, 3=z)", + "loc": "Domain boundary (-1=begin, 1=end)", + "centroid": "Patch center coordinate", + "length": "Patch dimension", + "radius": "Patch radius (geometry=2)", + }, + "simplex_params": { + "perturb_dens": "Enable simplex density perturbation", + "perturb_dens_freq": "Density perturbation frequency", + "perturb_dens_scale": "Density perturbation amplitude", + "perturb_dens_offset": "Density perturbation offset seed", + "perturb_vel": "Enable simplex velocity perturbation", + "perturb_vel_freq": "Velocity 
perturbation frequency", + "perturb_vel_scale": "Velocity perturbation amplitude", + "perturb_vel_offset": "Velocity perturbation offset seed", + }, + "fluid_pp": { + "gamma": "Specific heat ratio (EOS)", + "pi_inf": "Stiffness pressure (EOS)", + "cv": "Specific heat at constant volume", + "qv": "Heat of formation", + "qvp": "Heat of formation derivative", + }, +} + +# Tag → display name for docs. Dict order = priority when a param has multiple tags. +TAG_DISPLAY_NAMES = { + "bubbles": "Bubble model", + "mhd": "MHD", + "chemistry": "Chemistry", + "time": "Time-stepping", + "grid": "Grid", + "weno": "WENO", + "viscosity": "Viscosity", + "elasticity": "Elasticity", + "surface_tension": "Surface tension", + "acoustic": "Acoustic", + "ib": "Immersed boundary", + "probes": "Probe/integral", + "riemann": "Riemann solver", + "relativity": "Relativity", + "output": "Output", + "bc": "Boundary condition", +} + +# Prefix → hint for untagged simple params +PREFIX_HINTS = { + "mixlayer_": "Mixing layer parameter", + "nv_uvm_": "GPU memory management", + "ic_": "Initial condition parameter", +} + + +def _lookup_hint(name): + """Auto-derive constraint hint from HINTS dict using family+attribute matching.""" + if '%' not in name: + # Check PREFIX_HINTS for simple params + for prefix, label in PREFIX_HINTS.items(): + if name.startswith(prefix): + return label + return "" + # Compound name: extract family and attribute + prefix, attr_full = name.split('%', 1) + # Normalize family: "bc_x" → "bc", "patch_bc(1)" → "patch" + family = re.sub(r'[_(].*', '', prefix) + if family not in HINTS: + # Fallback: keep underscores — "patch_bc" → "patch_bc", "simplex_params" → "simplex_params" + m = re.match(r'^[a-zA-Z_]+', prefix) + family = m.group(0) if m else "" + if family not in HINTS: + return "" + # Strip index from attr: "vel_in(1)" → "vel_in" + m = re.match(r'^[a-zA-Z_0-9]+', attr_full) + if not m: + return "" + attr = m.group(0) + return HINTS[family].get(attr, "") + + # 
============================================================================ # Schema Validation for Constraints and Dependencies # ============================================================================ # Uses rapidfuzz for "did you mean?" suggestions when typos are detected -_VALID_CONSTRAINT_KEYS = {"choices", "min", "max"} -_VALID_DEPENDENCY_KEYS = {"when_true", "when_set"} -_VALID_CONDITION_KEYS = {"requires", "recommends"} +_VALID_CONSTRAINT_KEYS = {"choices", "min", "max", "value_labels"} +_VALID_DEPENDENCY_KEYS = {"when_true", "when_set", "when_value"} +_VALID_CONDITION_KEYS = {"requires", "recommends", "requires_value"} def _validate_constraint(param_name: str, constraint: Dict[str, Any]) -> None: @@ -380,6 +486,16 @@ def _validate_constraint(param_name: str, constraint: Dict[str, Any]) -> None: raise ValueError(f"Constraint 'min' for '{param_name}' must be a number") if "max" in constraint and not isinstance(constraint["max"], (int, float)): raise ValueError(f"Constraint 'max' for '{param_name}' must be a number") + if "value_labels" in constraint: + if not isinstance(constraint["value_labels"], dict): + raise ValueError(f"Constraint 'value_labels' for '{param_name}' must be a dict") + if "choices" in constraint: + for key in constraint["value_labels"]: + if key not in constraint["choices"]: + raise ValueError( + f"value_labels key {key!r} for '{param_name}' " + f"not in choices {constraint['choices']}" + ) def _validate_dependency(param_name: str, dependency: Dict[str, Any]) -> None: @@ -398,30 +514,55 @@ def _validate_dependency(param_name: str, dependency: Dict[str, Any]) -> None: ) ) - for condition_key in ["when_true", "when_set"]: - if condition_key in dependency: - condition = dependency[condition_key] - if not isinstance(condition, dict): + def _validate_condition(cond_label: str, condition: Any) -> None: + """Validate a condition dict (shared by when_true, when_set, when_value entries).""" + if not isinstance(condition, dict): + raise 
ValueError( + f"Dependency '{cond_label}' for '{param_name}' must be a dict" + ) + invalid_cond_keys = set(condition.keys()) - _VALID_CONDITION_KEYS + if invalid_cond_keys: + first_invalid = next(iter(invalid_cond_keys)) + raise ValueError( + invalid_key_error( + f"condition in '{cond_label}' for '{param_name}'", + first_invalid, + _VALID_CONDITION_KEYS + ) + ) + for req_key in ["requires", "recommends"]: + if req_key in condition and not isinstance(condition[req_key], list): raise ValueError( - f"Dependency '{condition_key}' for '{param_name}' must be a dict" + f"Dependency '{cond_label}/{req_key}' for '{param_name}' " + "must be a list" ) - invalid_cond_keys = set(condition.keys()) - _VALID_CONDITION_KEYS - if invalid_cond_keys: - first_invalid = next(iter(invalid_cond_keys)) + if "requires_value" in condition: + rv = condition["requires_value"] + if not isinstance(rv, dict): raise ValueError( - invalid_key_error( - f"condition in '{condition_key}' for '{param_name}'", - first_invalid, - _VALID_CONDITION_KEYS - ) + f"Dependency '{cond_label}/requires_value' for '{param_name}' " + "must be a dict" ) - for req_key in ["requires", "recommends"]: - if req_key in condition and not isinstance(condition[req_key], list): + for rv_param, rv_vals in rv.items(): + if not isinstance(rv_vals, list): raise ValueError( - f"Dependency '{condition_key}/{req_key}' for '{param_name}' " - "must be a list" + f"Dependency '{cond_label}/requires_value/{rv_param}' " + f"for '{param_name}' must be a list" ) + for condition_key in ["when_true", "when_set"]: + if condition_key in dependency: + _validate_condition(condition_key, dependency[condition_key]) + + if "when_value" in dependency: + wv = dependency["when_value"] + if not isinstance(wv, dict): + raise ValueError( + f"Dependency 'when_value' for '{param_name}' must be a dict" + ) + for val, condition in wv.items(): + _validate_condition(f"when_value/{val}", condition) + def _validate_all_constraints(constraints: Dict[str, Dict]) -> 
None: """Validate all constraint definitions.""" @@ -435,36 +576,107 @@ def _validate_all_dependencies(dependencies: Dict[str, Dict]) -> None: _validate_dependency(param_name, dependency) +def get_value_label(param_name: str, value: int) -> str: + """Look up the human-readable label for a parameter's integer code. + + Returns the label string, or ``str(value)`` when no label is defined. + This is the single source of truth for value ↔ label mappings. + """ + constraint = CONSTRAINTS.get(param_name) + if constraint is None: + return str(value) + labels = constraint.get("value_labels") + if labels is None: + return str(value) + return labels.get(value, str(value)) + + # Parameter constraints (choices, min, max) CONSTRAINTS = { # Reconstruction - "weno_order": {"choices": [0, 1, 3, 5, 7]}, # 0 for MUSCL mode - "recon_type": {"choices": [1, 2]}, # 1=WENO, 2=MUSCL - "muscl_order": {"choices": [1, 2]}, - "muscl_lim": {"choices": [1, 2, 3, 4, 5]}, # minmod, MC, Van Albada, Van Leer, SUPERBEE + "weno_order": { + "choices": [0, 1, 3, 5, 7], + "value_labels": {0: "MUSCL mode", 1: "1st order", 3: "WENO3", 5: "WENO5", 7: "WENO7"}, + }, + "recon_type": { + "choices": [1, 2], + "value_labels": {1: "WENO", 2: "MUSCL"}, + }, + "muscl_order": { + "choices": [1, 2], + "value_labels": {1: "1st order", 2: "2nd order"}, + }, + "muscl_lim": { + "choices": [1, 2, 3, 4, 5], + "value_labels": {1: "minmod", 2: "MC", 3: "Van Albada", 4: "Van Leer", 5: "SUPERBEE"}, + }, # Time stepping - "time_stepper": {"choices": [1, 2, 3]}, # 1=Euler, 2=TVD-RK2, 3=TVD-RK3 + "time_stepper": { + "choices": [1, 2, 3], + "value_labels": {1: "RK1 (Forward Euler)", 2: "RK2", 3: "RK3 (SSP)"}, + }, # Riemann solver - "riemann_solver": {"choices": [1, 2, 3, 4, 5]}, # HLL, HLLC, Exact, HLLD, LF - "wave_speeds": {"choices": [1, 2]}, # direct, pressure - "avg_state": {"choices": [1, 2]}, # Roe, arithmetic + "riemann_solver": { + "choices": [1, 2, 3, 4, 5], + "value_labels": {1: "HLL", 2: "HLLC", 3: "Exact", 4: "HLLD", 
5: "Lax-Friedrichs"}, + }, + "wave_speeds": { + "choices": [1, 2], + "value_labels": {1: "direct", 2: "pressure"}, + }, + "avg_state": { + "choices": [1, 2], + "value_labels": {1: "Roe", 2: "arithmetic"}, + }, # Model equations - "model_eqns": {"choices": [1, 2, 3, 4]}, # gamma-law, 5-eq, 6-eq, 4-eq + "model_eqns": { + "choices": [1, 2, 3, 4], + "value_labels": {1: "Gamma-law", 2: "5-Equation", 3: "6-Equation", 4: "4-Equation"}, + }, # Bubbles - "bubble_model": {"choices": [1, 2, 3]}, # Gilmore, Keller-Miksis, RP + "bubble_model": { + "choices": [1, 2, 3], + "value_labels": {1: "Gilmore", 2: "Keller-Miksis", 3: "Rayleigh-Plesset"}, + }, # Output - "format": {"choices": [1, 2]}, # Silo, binary - "precision": {"choices": [1, 2]}, # single, double + "format": { + "choices": [1, 2], + "value_labels": {1: "Silo", 2: "binary"}, + }, + "precision": { + "choices": [1, 2], + "value_labels": {1: "single", 2: "double"}, + }, + + # Time stepping (must be positive) + "dt": {"min": 0}, + "t_stop": {"min": 0}, + "t_save": {"min": 0}, + "t_step_save": {"min": 1}, + "t_step_print": {"min": 1}, + "cfl_target": {"min": 0}, + "cfl_max": {"min": 0}, + + # WENO + "weno_eps": {"min": 0}, + + # Physics (must be non-negative) + "R0ref": {"min": 0}, + "sigma": {"min": 0}, # Counts (must be positive) "num_fluids": {"min": 1, "max": 10}, "num_patches": {"min": 0, "max": 10}, "num_ibs": {"min": 0, "max": 10}, + "num_source": {"min": 1}, + "num_probes": {"min": 1}, + "num_integrals": {"min": 1}, + "nb": {"min": 1}, "m": {"min": 0}, "n": {"min": 0}, "p": {"min": 0}, @@ -475,6 +687,18 @@ def _validate_all_dependencies(dependencies: Dict[str, Dict]) -> None: "bubbles_euler": { "when_true": { "recommends": ["nb", "polytropic"], + "requires_value": { + "model_eqns": [2, 4], + "riemann_solver": [2], + "avg_state": [2], + }, + } + }, + "model_eqns": { + "when_value": { + 2: {"requires": ["num_fluids"]}, + 3: {"requires_value": {"riemann_solver": [2], "avg_state": [2], "wave_speeds": [1]}}, + 4: 
{"requires": ["rhoref", "pref"], "requires_value": {"num_fluids": [1]}}, } }, "viscous": { @@ -484,7 +708,7 @@ def _validate_all_dependencies(dependencies: Dict[str, Dict]) -> None: }, "polydisperse": { "when_true": { - "requires": ["nb"], + "requires": ["nb", "poly_sigma"], } }, "chemistry": { @@ -509,21 +733,105 @@ def _validate_all_dependencies(dependencies: Dict[str, Dict]) -> None: }, "probe_wrt": { "when_true": { - "requires": ["num_probes"], + "requires": ["num_probes", "fd_order"], + } + }, + "stretch_x": { + "when_true": { + "requires": ["a_x", "x_a", "x_b"], + } + }, + "stretch_y": { + "when_true": { + "requires": ["a_y", "y_a", "y_b"], + } + }, + "stretch_z": { + "when_true": { + "requires": ["a_z", "z_a", "z_b"], + } + }, + "bf_x": { + "when_true": { + "requires": ["k_x", "w_x", "p_x", "g_x"], + } + }, + "bf_y": { + "when_true": { + "requires": ["k_y", "w_y", "p_y", "g_y"], + } + }, + "bf_z": { + "when_true": { + "requires": ["k_z", "w_z", "p_z", "g_z"], + } + }, + "teno": { + "when_true": { + "requires": ["teno_CT"], + } + }, + "recon_type": { + "when_value": { + 2: {"recommends": ["muscl_order", "muscl_lim"]}, + } + }, + "surface_tension": { + "when_true": { + "requires": ["sigma"], + } + }, + "mhd": { + "when_true": { + "recommends": ["hyper_cleaning"], + } + }, + "relativity": { + "when_true": { + "requires": ["mhd"], + } + }, + "schlieren_wrt": { + "when_true": { + "requires": ["fd_order"], + } + }, + "cfl_adap_dt": { + "when_true": { + "recommends": ["cfl_target"], + } + }, + "cfl_dt": { + "when_true": { + "recommends": ["cfl_target"], + } + }, + "integral_wrt": { + "when_true": { + "requires": ["fd_order"], } }, } -def _r(name, ptype, tags=None, desc=None): +def _r(name, ptype, tags=None, desc=None, hint=None): """Register a parameter with optional feature tags and description.""" + if hint is None: + hint = _lookup_hint(name) + description = desc if desc else _auto_describe(name) + constraint = CONSTRAINTS.get(name) + if constraint and 
"value_labels" in constraint: + labels = constraint["value_labels"] + suffix = ", ".join(f"{v}={labels[v]}" for v in sorted(labels)) + description = f"{description} ({suffix})" REGISTRY.register(ParamDef( name=name, param_type=ptype, - description=desc if desc else _auto_describe(name), + description=description, case_optimization=(name in CASE_OPT_PARAMS), - constraints=CONSTRAINTS.get(name), + constraints=constraint, dependencies=DEPENDENCIES.get(name), tags=tags if tags else set(), + hint=hint, )) diff --git a/toolchain/mfc/params/descriptions.py b/toolchain/mfc/params/descriptions.py index 03fbc60ba6..b689e3a536 100644 --- a/toolchain/mfc/params/descriptions.py +++ b/toolchain/mfc/params/descriptions.py @@ -44,7 +44,7 @@ "t_step_print": "Time step interval for printing info", "t_stop": "Simulation stop time", "t_save": "Time interval for saving data", - "time_stepper": "Time integration scheme (1=Euler, 2=TVD-RK2, 3=TVD-RK3)", + "time_stepper": "Time integration scheme", "cfl_adap_dt": "Enable adaptive time stepping based on CFL", "cfl_const_dt": "Use constant CFL for time stepping", "cfl_target": "Target CFL number for adaptive time stepping", @@ -54,7 +54,7 @@ "adap_dt_max_iters": "Maximum iterations for adaptive time stepping", # Model equations - "model_eqns": "Model equations (1=gamma-law, 2=5-eq, 3=6-eq, 4=4-eq)", + "model_eqns": "Model equations", "num_fluids": "Number of fluid components", "num_patches": "Number of initial condition patches", "mpp_lim": "Enable mixture pressure positivity limiter", @@ -62,7 +62,7 @@ "alt_soundspeed": "Use alternative sound speed formulation", # WENO reconstruction - "weno_order": "Order of WENO reconstruction (1, 3, 5, or 7)", + "weno_order": "Order of WENO reconstruction", "weno_eps": "WENO epsilon parameter for smoothness", "mapped_weno": "Enable mapped WENO scheme", "wenoz": "Enable WENO-Z scheme", @@ -75,14 +75,14 @@ "null_weights": "Allow null WENO weights", # MUSCL reconstruction - "recon_type": "Reconstruction 
type (1=WENO, 2=MUSCL)", + "recon_type": "Reconstruction type", "muscl_order": "Order of MUSCL reconstruction", "muscl_lim": "MUSCL limiter type", # Riemann solver - "riemann_solver": "Riemann solver (1=HLL, 2=HLLC, 3=exact)", - "wave_speeds": "Wave speed estimates (1=direct, 2=pressure)", - "avg_state": "Average state for Riemann solver (1=Roe, 2=arithmetic)", + "riemann_solver": "Riemann solver", + "wave_speeds": "Wave speed estimates", + "avg_state": "Average state for Riemann solver", "low_Mach": "Low Mach number correction", # Boundary conditions @@ -97,7 +97,7 @@ # Physics models "bubbles_euler": "Enable Euler-Euler bubble model", "bubbles_lagrange": "Enable Lagrangian bubble tracking", - "bubble_model": "Bubble dynamics model (1=Gilmore, 2=Keller-Miksis, 3=Rayleigh-Plesset)", + "bubble_model": "Bubble dynamics model", "polytropic": "Enable polytropic gas behavior for bubbles", "polydisperse": "Enable polydisperse bubble distribution", "nb": "Number of bubble bins for polydisperse model", @@ -125,8 +125,8 @@ "integral_wrt": "Write integral data", "parallel_io": "Enable parallel I/O", "file_per_process": "Write separate file per MPI process", - "format": "Output format (1=Silo, 2=binary)", - "precision": "Output precision (1=single, 2=double)", + "format": "Output format", + "precision": "Output precision", "schlieren_wrt": "Write schlieren images", "rho_wrt": "Write density field", "pres_wrt": "Write pressure field", @@ -496,24 +496,27 @@ def get_description(param_name: str) -> str: - """Get description for a parameter from registry or fallback sources.""" - # Primary source: ParamDef.description from registry - from . import REGISTRY # pylint: disable=import-outside-toplevel - param = REGISTRY.all_params.get(param_name) - if param and param.description: - return param.description + """Get description for a parameter from hand-curated or auto-generated sources. 
- # Fallback 1: manual descriptions dict (legacy, will be removed) + Priority: hand-curated DESCRIPTIONS > PATTERNS > auto-generated param.description. + """ + # 1. Hand-curated descriptions (highest quality) if param_name in DESCRIPTIONS: return DESCRIPTIONS[param_name] - # Fallback 2: pattern matching for indexed params + # 2. Pattern matching for indexed params (hand-curated templates) for pattern, template in PATTERNS: match = re.fullmatch(pattern, param_name) if match: return template.format(*match.groups()) - # Fallback 3: naming convention inference + # 3. Auto-generated description from registry (set by _auto_describe at registration) + from . import REGISTRY # pylint: disable=import-outside-toplevel + param = REGISTRY.all_params.get(param_name) + if param and param.description: + return param.description + + # 4. Last resort: naming convention inference return _infer_from_naming(param_name) diff --git a/toolchain/mfc/params/errors.py b/toolchain/mfc/params/errors.py index ed885f8393..88a2aecf9c 100644 --- a/toolchain/mfc/params/errors.py +++ b/toolchain/mfc/params/errors.py @@ -118,6 +118,37 @@ def dependency_recommendation( return f"When {format_param(param)} is set, consider also setting {format_param(recommended_param)}" +def dependency_value_error( + param: str, + condition: Optional[str], + required_param: str, + expected_values: list, + got: Any, +) -> str: + """ + Create an error for a dependency that requires a specific value. + + Args: + param: Parameter that has the dependency + condition: Condition string (e.g., "=T", "=3") + required_param: Parameter that must have a specific value + expected_values: List of acceptable values + got: Actual value found + + Returns: + Formatted error message. 
+ """ + if condition: + return ( + f"{format_param(param)}{condition} requires {format_param(required_param)} " + f"to be one of {expected_values}, got {format_value(got)}" + ) + return ( + f"{format_param(param)} requires {format_param(required_param)} " + f"to be one of {expected_values}, got {format_value(got)}" + ) + + def required_error(param: str, context: Optional[str] = None) -> str: """ Create a missing required parameter error message. diff --git a/toolchain/mfc/params/generators/docs_gen.py b/toolchain/mfc/params/generators/docs_gen.py index 7699a12c08..6849c8a01d 100644 --- a/toolchain/mfc/params/generators/docs_gen.py +++ b/toolchain/mfc/params/generators/docs_gen.py @@ -12,17 +12,17 @@ from ..schema import ParamType from ..registry import REGISTRY from ..descriptions import get_description +from ..ast_analyzer import analyze_case_validator, classify_message from .. import definitions # noqa: F401 pylint: disable=unused-import def _get_family(name: str) -> str: """Extract family name from parameter (e.g., 'patch_icpp' from 'patch_icpp(1)%vel(1)').""" # Handle indexed parameters - match = re.match(r'^([a-z_]+)', name) + match = re.match(r'^([a-zA-Z_]+)', name) if match: base = match.group(1) - # Check if it's a known family pattern - if any(name.startswith(f"{base}(") or name.startswith(f"{base}%") for _ in [1]): + if name.startswith(f"{base}(") or name.startswith(f"{base}%"): return base return "general" @@ -103,14 +103,20 @@ def _type_to_str(param_type: ParamType) -> str: def _format_constraints(param) -> str: - """Format constraints as readable string.""" + """Format constraints as readable string with value labels when available.""" if not param.constraints: return "" parts = [] c = param.constraints if "choices" in c: - parts.append(f"Values: {c['choices']}") + labels = c.get("value_labels", {}) + if labels: + items = [f"{v}={labels[v]}" if v in labels else str(v) + for v in c["choices"]] + parts.append(", ".join(items)) + else: + 
parts.append(f"Values: {c['choices']}") if "min" in c: parts.append(f"Min: {c['min']}") if "max" in c: @@ -119,8 +125,218 @@ def _format_constraints(param) -> str: return ", ".join(parts) +def _build_param_name_pattern(): + """Build a regex pattern that matches known parameter names at word boundaries. + + Uses longest-match-first to avoid partial matches (e.g., 'model_eqns' before 'model'). + Only matches names that look like identifiers (avoids matching 'm' inside 'must'). + """ + all_names = sorted(REGISTRY.all_params.keys(), key=len, reverse=True) + # Only include names >= 2 chars to avoid false positives with single-letter params + # and names that are simple identifiers (no % or parens, which need escaping) + safe_names = [n for n in all_names if len(n) >= 2 and re.match(r'^[a-zA-Z_]\w*$', n)] + if not safe_names: + return None + pattern = r'\b(' + '|'.join(re.escape(n) for n in safe_names) + r')\b' + return re.compile(pattern) + + +# Matches compound param names like bub_pp%mu_g, fluid_pp(1)%Re(1), x_output%beg +_COMPOUND_NAME_RE = re.compile(r'\b\w+(?:\([^)]*\))?(?:%\w+(?:\([^)]*\))?)+') + + +def _backtick_params(msg: str, pattern) -> str: + """Wrap parameter names in backticks for markdown rendering. + + Handles three cases in order: + 1. Compound names with % (e.g. bub_pp%mu_g, x_output%beg) + 2. Known registry param names (e.g. model_eqns, weno_order) + 3. Snake_case identifiers not in registry (e.g. cluster_type, smooth_type) + """ + # 1. Wrap compound names (word%word patterns) — must come first + msg = _COMPOUND_NAME_RE.sub(lambda m: f'`{m.group(0)}`', msg) + + # 2. Wrap known simple param names, only outside existing backtick spans + if pattern is not None: + parts = msg.split('`') + for i in range(0, len(parts), 2): + parts[i] = pattern.sub(r'`\1`', parts[i]) + msg = '`'.join(parts) + + # 3. 
Wrap remaining snake_case identifiers (at least one underscore) + parts = msg.split('`') + for i in range(0, len(parts), 2): + parts[i] = re.sub(r'\b([a-z]\w*_\w+)\b', r'`\1`', parts[i]) + msg = '`'.join(parts) + + return msg + + +def _escape_pct_outside_backticks(text: str) -> str: + """Escape % as %% for Doxygen, but not inside backtick code spans.""" + parts = text.split('`') + for i in range(0, len(parts), 2): + parts[i] = parts[i].replace('%', '%%') + return '`'.join(parts) + + +# Lazily initialized at module level on first use +_PARAM_PATTERN = None + + +def _get_param_pattern(): + global _PARAM_PATTERN # noqa: PLW0603 pylint: disable=global-statement + if _PARAM_PATTERN is None: + _PARAM_PATTERN = _build_param_name_pattern() + return _PARAM_PATTERN + + +def _build_reverse_dep_map() -> Dict[str, List[Tuple[str, str]]]: + """Build map from target param -> [(relation, source_param), ...] from DEPENDENCIES.""" + from ..definitions import DEPENDENCIES # pylint: disable=import-outside-toplevel + reverse: Dict[str, List[Tuple[str, str]]] = {} + for param, dep in DEPENDENCIES.items(): + if "when_true" in dep: + wt = dep["when_true"] + for r in wt.get("requires", []): + reverse.setdefault(r, []).append(("required by", param)) + for r in wt.get("recommends", []): + reverse.setdefault(r, []).append(("recommended for", param)) + if "when_value" in dep: + for val, subspec in dep["when_value"].items(): + for r in subspec.get("requires", []): + reverse.setdefault(r, []).append(("required by", f"{param}={val}")) + for r in subspec.get("recommends", []): + reverse.setdefault(r, []).append(("recommended for", f"{param}={val}")) + return reverse + + +_REVERSE_DEPS = None + + +def _get_reverse_deps(): + global _REVERSE_DEPS # noqa: PLW0603 pylint: disable=global-statement + if _REVERSE_DEPS is None: + _REVERSE_DEPS = _build_reverse_dep_map() + return _REVERSE_DEPS + + +def _format_tag_annotation(param_name: str, param) -> str: # pylint: disable=too-many-locals + """ + Return a 
short annotation for params with no schema constraints and no AST rules. + + Checks (in order): own DEPENDENCIES, output flag tags, reverse dependencies, + data-driven hints (covering prefix-group and compound-name params), and feature tag labels. + """ + # 1. Own DEPENDENCIES info + if param.dependencies: + dep = param.dependencies + if "when_true" in dep: + wt = dep["when_true"] + if "requires" in wt: + req = ", ".join(f"`{r}`" for r in wt["requires"]) + return f"Requires {req} when enabled" + if "requires_value" in wt: + parts = [] + for k, vals in wt["requires_value"].items(): + parts.append(f"`{k}` in {vals}") + return "Requires " + ", ".join(parts) + if "recommends" in wt: + rec = ", ".join(f"`{r}`" for r in wt["recommends"]) + return f"Recommends {rec}" + + # 2. Tag-based output flag label (specific labels for LOG output params) + if "output" in param.tags and param.param_type == ParamType.LOG: + if "bubbles" in param.tags: + return "Lagrangian output flag" + if "chemistry" in param.tags: + return "Chemistry output flag" + return "Post-processing output flag" + + # 3. Reverse dependencies (params required/recommended by other features) + reverse = _get_reverse_deps() + if param_name in reverse: + entries = reverse[param_name] + parts = [] + for relation, source in entries[:2]: + parts.append(f"Required by `{source}`" if relation == "required by" + else f"Recommended for `{source}`") + return "; ".join(parts) + + # 4. ParamDef hint (data-driven from definitions.py) + if param.hint: + return param.hint + + # 5.
Tag-based label (from TAG_DISPLAY_NAMES in definitions.py) + from ..definitions import TAG_DISPLAY_NAMES # pylint: disable=import-outside-toplevel + for tag, display_name in TAG_DISPLAY_NAMES.items(): + if tag in param.tags: + return f"{display_name} parameter" + + return "" + + +def _format_validator_rules(param_name: str, by_trigger: Dict[str, list], # pylint: disable=too-many-locals + by_param: Dict[str, list] | None = None) -> str: + """Format AST-extracted validator rules for a parameter's Constraints column. + + Gets rules where this param is the trigger. Falls back to by_param + (rules that mention this param) when no trigger rules exist. + """ + rules = by_trigger.get(param_name, []) + if not rules and by_param: + rules = by_param.get(param_name, []) + if not rules: + return "" + + pattern = _get_param_pattern() + + # Deduplicate messages (same message can appear from multiple loop iterations) + seen = set() + unique_rules = [] + for r in rules: + if r.message not in seen: + seen.add(r.message) + unique_rules.append(r) + + # Classify and pick representative messages + requirements = [] + incompatibilities = [] + ranges = [] + others = [] + + for r in unique_rules: + kind = classify_message(r.message) + msg = _backtick_params(r.message, pattern) + if kind == "requirement": + requirements.append(msg) + elif kind == "incompatibility": + incompatibilities.append(msg) + elif kind == "range": + ranges.append(msg) + else: + others.append(msg) + + # Build concise output - show up to 3 rules total, prioritized + parts = [] + budget = 3 + for group in [requirements, incompatibilities, ranges, others]: + for msg in group: + if budget <= 0: + break + parts.append(msg) + budget -= 1 + + return "; ".join(parts) + + def generate_parameter_docs() -> str: # pylint: disable=too-many-locals,too-many-statements """Generate markdown documentation for all parameters.""" + # AST-extract rules from case_validator.py + analysis = analyze_case_validator() + by_trigger = 
analysis["by_trigger"] + by_param = analysis["by_param"] + lines = [ "@page parameters Case Parameters Reference", "", @@ -225,10 +441,25 @@ def generate_parameter_docs() -> str: # pylint: disable=too-many-locals,too-man # Use pattern view if it reduces rows, otherwise show full table if len(patterns) < len(params): # Pattern view - shows collapsed patterns + # Check if any member of a pattern has constraints + pattern_has_constraints = False + for _pattern, examples in patterns.items(): + for ex in examples: + p = REGISTRY.all_params[ex] + if p.constraints or ex in by_trigger or ex in by_param: + pattern_has_constraints = True + break + if pattern_has_constraints: + break + lines.append("### Patterns") lines.append("") - lines.append("| Pattern | Example | Description |") - lines.append("|---------|---------|-------------|") + if pattern_has_constraints: + lines.append("| Pattern | Example | Description | Constraints |") + lines.append("|---------|---------|-------------|-------------|") + else: + lines.append("| Pattern | Example | Description |") + lines.append("|---------|---------|-------------|") for pattern, examples in sorted(patterns.items()): example = examples[0] @@ -236,29 +467,44 @@ def generate_parameter_docs() -> str: # pylint: disable=too-many-locals,too-man # Truncate long descriptions if len(desc) > 60: desc = desc[:57] + "..." 
- # Escape % for Doxygen + # Escape % for Doxygen (even inside backtick code spans) pattern_escaped = _escape_percent(pattern) example_escaped = _escape_percent(example) - lines.append(f"| `{pattern_escaped}` | `{example_escaped}` | {desc} |") + desc = _escape_percent(desc) + if pattern_has_constraints: + p = REGISTRY.all_params[example] + constraints = _format_constraints(p) + deps = _format_validator_rules(example, by_trigger, by_param) + extra = "; ".join(filter(None, [constraints, deps])) + if not extra: + extra = _format_tag_annotation(example, p) + extra = _escape_pct_outside_backticks(extra) + lines.append(f"| `{pattern_escaped}` | `{example_escaped}` | {desc} | {extra} |") + else: + lines.append(f"| `{pattern_escaped}` | `{example_escaped}` | {desc} |") lines.append("") else: # Full table - no patterns to collapse - lines.append("| Parameter | Type | Description |") - lines.append("|-----------|------|-------------|") + lines.append("| Parameter | Type | Description | Constraints |") + lines.append("|-----------|------|-------------|-------------|") for name, param in params: type_str = _type_to_str(param.param_type) desc = get_description(name) or "" - constraints = _format_constraints(param) - if constraints: - desc = f"{desc} ({constraints})" if desc else constraints # Truncate long descriptions if len(desc) > 80: desc = desc[:77] + "..." 
- # Escape % for Doxygen + constraints = _format_constraints(param) + deps = _format_validator_rules(name, by_trigger, by_param) + extra = "; ".join(filter(None, [constraints, deps])) + if not extra: + extra = _format_tag_annotation(name, param) + extra = _escape_pct_outside_backticks(extra) + # Escape % for Doxygen (even inside backtick code spans) name_escaped = _escape_percent(name) - lines.append(f"| `{name_escaped}` | {type_str} | {desc} |") + desc = _escape_percent(desc) + lines.append(f"| `{name_escaped}` | {type_str} | {desc} | {extra} |") lines.append("") diff --git a/toolchain/mfc/params/schema.py b/toolchain/mfc/params/schema.py index 4d282c08e0..8999dc903a 100644 --- a/toolchain/mfc/params/schema.py +++ b/toolchain/mfc/params/schema.py @@ -40,7 +40,7 @@ def json_schema(self) -> Dict[str, Any]: @dataclass -class ParamDef: +class ParamDef: # pylint: disable=too-many-instance-attributes """ Definition of a single MFC parameter. @@ -60,6 +60,7 @@ class ParamDef: constraints: Optional[Dict[str, Any]] = None # {"choices": [...], "min": N, "max": N} dependencies: Optional[Dict[str, Any]] = None # {"requires": [...], "recommends": [...]} tags: Set[str] = field(default_factory=set) # Feature tags: "mhd", "bubbles", etc. + hint: str = "" # Constraint/usage hint for docs (e.g. "Used with grcbc_in") def __post_init__(self): # Validate name diff --git a/toolchain/mfc/params/validate.py b/toolchain/mfc/params/validate.py index b004ed4797..04548b4066 100644 --- a/toolchain/mfc/params/validate.py +++ b/toolchain/mfc/params/validate.py @@ -27,11 +27,12 @@ 3. 
Physics validation (via case_validator.py) """ -from typing import Dict, Any, List, Tuple +from typing import Dict, Any, List, Optional, Tuple from .registry import REGISTRY from .errors import ( dependency_error, dependency_recommendation, + dependency_value_error, format_error_list, unknown_param_error, ) @@ -95,6 +96,37 @@ def validate_constraints(params: Dict[str, Any]) -> List[str]: return errors +def _check_condition( # pylint: disable=too-many-arguments,too-many-positional-arguments + name: str, + condition: Dict[str, Any], + condition_label: Optional[str], + params: Dict[str, Any], + errors: List[str], + warnings: List[str], +) -> None: + """Check a single condition dict (requires, recommends, requires_value).""" + if "requires" in condition: + for req in condition["requires"]: + if req not in params: + errors.append(dependency_error(name, req, condition_label)) + + if "recommends" in condition: + for rec in condition["recommends"]: + if rec not in params: + warnings.append(dependency_recommendation(name, rec, condition_label)) + + if "requires_value" in condition: + for req_param, expected_vals in condition["requires_value"].items(): + if req_param not in params: + errors.append(dependency_error(name, req_param, condition_label)) + else: + got = params[req_param] + if got not in expected_vals: + errors.append(dependency_value_error( + name, condition_label, req_param, expected_vals, got, + )) + + def check_dependencies(params: Dict[str, Any]) -> Tuple[List[str], List[str]]: # pylint: disable=too-many-branches """ Check parameter dependencies. 
@@ -119,33 +151,19 @@ def check_dependencies(params: Dict[str, Any]) -> Tuple[List[str], List[str]]: # Check "when_true" dependencies (for LOG params set to "T") if "when_true" in deps and value == "T": - when_true = deps["when_true"] - - # Required params - if "requires" in when_true: - for req in when_true["requires"]: - if req not in params: - errors.append(dependency_error(name, req, "=T")) - - # Recommended params - if "recommends" in when_true: - for rec in when_true["recommends"]: - if rec not in params: - warnings.append(dependency_recommendation(name, rec, "=T")) + _check_condition(name, deps["when_true"], "=T", params, errors, warnings) # Check "when_set" dependencies (for any param that's set) if "when_set" in deps: - when_set = deps["when_set"] - - if "requires" in when_set: - for req in when_set["requires"]: - if req not in params: - errors.append(dependency_error(name, req)) - - if "recommends" in when_set: - for rec in when_set["recommends"]: - if rec not in params: - warnings.append(dependency_recommendation(name, rec)) + _check_condition(name, deps["when_set"], None, params, errors, warnings) + + # Check "when_value" dependencies (value-specific conditions) + if "when_value" in deps: + for trigger_val, condition in deps["when_value"].items(): + if value == trigger_val: + _check_condition( + name, condition, f"={trigger_val}", params, errors, warnings, + ) return errors, warnings