From bcdddd214e1a31be36c1bacbf88ec006ac226b85 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 11 May 2026 11:09:39 +0200 Subject: [PATCH 1/6] externalize observation weights from Store layer (closes #28) Remove weight(obs) and is_unweighted() from the ObservationStore trait, delete the ObservationWeights enum, and drop the weights field from FactorMajorStore and ArrayStore. Weights now flow as Option<&[f64]> through the operator/preconditioner builders and live owned at the Solver layer as Option<Vec<f64>>. - WeightedDesignOperator::new(design, weights) takes the borrowed slice and computes sqrt_weights once; apply/apply_adjoint match on &self.sqrt_weights so the unweighted hot path stays branch-free. - WeightedDesign::rmatvec_wdt(weights, r, x) and the static WeightedDesign::uid_weight(weights, uid) helper replace the prior store-driven branches. gramian_diagonal and its five tests are deleted as dead code (no production callers). - build_preconditioner, build_local_domains, CrossTab::build_for_pair*, and the accumulate_cross_block family all thread Option<&[f64]>. - Solver gains a weights: Option<Vec<f64>> field; from_design and from_design_with_preconditioner take it. orchestrate::solve / solve_batch and the PyO3 bridge build the owned vector from the optional numpy weights input. Python signatures unchanged. 
--- crates/within-py/src/lib.rs | 34 ++-- crates/within/benches/fixest.rs | 9 +- crates/within/benches/store_backend.rs | 17 +- crates/within/examples/solve_demo.rs | 5 +- crates/within/src/domain.rs | 105 +++++------- crates/within/src/domain/factor_pairs.rs | 22 +-- crates/within/src/lib.rs | 4 +- crates/within/src/observation.rs | 159 +++++------------- crates/within/src/operator.rs | 17 +- .../within/src/operator/gramian/cross_tab.rs | 21 ++- crates/within/src/operator/gramian/tests.rs | 40 ++--- crates/within/src/operator/preconditioner.rs | 6 +- crates/within/src/operator/tests.rs | 30 ++-- crates/within/src/solver.rs | 41 +++-- crates/within/tests/array_store.rs | 14 +- .../tests/common/orchestrate_helpers.rs | 13 +- crates/within/tests/domain.rs | 140 ++------------- crates/within/tests/edge_cases.rs | 15 +- crates/within/tests/error_paths.rs | 26 +-- crates/within/tests/orchestrate_solve.rs | 28 ++- crates/within/tests/properties.rs | 4 +- crates/within/tests/property_gaps.rs | 6 +- crates/within/tests/solver.rs | 2 +- 23 files changed, 261 insertions(+), 497 deletions(-) diff --git a/crates/within-py/src/lib.rs b/crates/within-py/src/lib.rs index 3c977cd..5476b44 100644 --- a/crates/within-py/src/lib.rs +++ b/crates/within-py/src/lib.rs @@ -44,7 +44,7 @@ use within::config::{ SolverParams, DEFAULT_DENSE_SCHUR_THRESHOLD, }; use within::domain::WeightedDesign; -use within::observation::{FactorMajorStore, ObservationWeights}; +use within::observation::FactorMajorStore; use within::{ solve as solve_native, solve_batch as solve_batch_native, FePreconditioner, Operator, SolveResult, Solver, @@ -700,27 +700,27 @@ impl PySolver { let factor_levels: Vec> = (0..n_factors) .map(|f| cats.column(f).iter().copied().collect()) .collect(); - let w = match &weights { - Some(w) => ObservationWeights::Dense(w.as_array().iter().copied().collect()), - None => ObservationWeights::Unit, - }; - let store = FactorMajorStore::new(factor_levels, w, n_obs).map_err(value_err)?; + let 
store = FactorMajorStore::new(factor_levels, n_obs).map_err(value_err)?; let design = WeightedDesign::from_store(store).map_err(value_err)?; + let weights_vec: Option> = weights + .as_ref() + .map(|w| w.as_array().iter().copied().collect()); // Handle pre-built FePreconditioner separately (uses a different constructor); // all other variants go through extract_preconditioner_config. - let solver = - if let Some(Ok(fe)) = preconditioner.map(|o| o.downcast::()) { - let fe_precond = fe.get().inner.clone(); - py.allow_threads(|| { - Solver::from_design_with_preconditioner(design, ¶ms, fe_precond) - }) + let solver = if let Some(Ok(fe)) = + preconditioner.map(|o| o.downcast::()) + { + let fe_precond = fe.get().inner.clone(); + py.allow_threads(|| { + Solver::from_design_with_preconditioner(design, weights_vec, ¶ms, fe_precond) + }) + .map_err(value_err)? + } else { + let precond = extract_preconditioner_config(py, preconditioner)?; + py.allow_threads(|| Solver::from_design(design, weights_vec, ¶ms, precond.as_ref())) .map_err(value_err)? - } else { - let precond = extract_preconditioner_config(py, preconditioner)?; - py.allow_threads(|| Solver::from_design(design, ¶ms, precond.as_ref())) - .map_err(value_err)? 
- }; + }; Ok(Self { solver }) } diff --git a/crates/within/benches/fixest.rs b/crates/within/benches/fixest.rs index d70936f..036617c 100644 --- a/crates/within/benches/fixest.rs +++ b/crates/within/benches/fixest.rs @@ -11,7 +11,7 @@ use within::config::{ ApproxCholConfig, LocalSolverConfig, Preconditioner, ReductionStrategy, SolverParams, }; use within::domain::WeightedDesign; -use within::observation::{FactorMajorStore, ObservationWeights}; +use within::observation::FactorMajorStore; use within::operator::WeightedDesignOperator; use within::Solver; @@ -76,8 +76,7 @@ fn generate_fixest_like_case( vec![indiv_id, year, firm_id] }; - let store = FactorMajorStore::new(factor_levels, ObservationWeights::Unit, case.n_obs) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(factor_levels, case.n_obs).expect("valid factor-major store"); let design = WeightedDesign::from_store(store).expect("valid design"); let mut x_true = vec![0.0; design.n_dofs]; @@ -140,7 +139,7 @@ fn run_lsmr_one_level(design: &WeightedDesign, y: &[f64], ac2: let cfg = one_level_local_solver(ac2); let precond = Preconditioner::Additive(cfg, ReductionStrategy::Auto); let solver = - Solver::from_design(design.clone(), ¶ms, Some(&precond)).expect("solver build"); + Solver::from_design(design.clone(), None, ¶ms, Some(&precond)).expect("solver build"); let _ = solver.solve(y).expect("solve"); } @@ -281,7 +280,7 @@ fn bench_matvec(c: &mut Criterion) { let (design, _y) = generate_fixest_like_case(case, 42); let n_dofs = design.n_dofs; let n_obs = design.n_rows; - let op = WeightedDesignOperator::new(&design); + let op = WeightedDesignOperator::new(&design, None); let x: Vec = (0..n_dofs).map(|i| (i as f64).sin()).collect(); let mut y = vec![0.0; n_obs]; group.bench_function(BenchmarkId::new("apply", &label), |b| { diff --git a/crates/within/benches/store_backend.rs b/crates/within/benches/store_backend.rs index 3f11ecf..572f44c 100644 --- a/crates/within/benches/store_backend.rs +++ 
b/crates/within/benches/store_backend.rs @@ -9,7 +9,7 @@ use rand::{Rng, SeedableRng}; use within::config::{Preconditioner, SolverParams}; use within::domain::WeightedDesign; -use within::observation::{ArrayStore, FactorMajorStore, ObservationWeights}; +use within::observation::{ArrayStore, FactorMajorStore}; use within::Solver; const TOL: f64 = 1e-6; @@ -100,10 +100,9 @@ fn bench_store_backends(c: &mut Criterion) { let factor_levels: Vec> = (0..p.categories_c.ncols()) .map(|q| p.categories_c.column(q).to_vec()) .collect(); - let store = - FactorMajorStore::new(factor_levels, ObservationWeights::Unit, *n_obs).unwrap(); + let store = FactorMajorStore::new(factor_levels, *n_obs).unwrap(); let design = WeightedDesign::from_store(store).unwrap(); - let solver = Solver::from_design(design, &p.params, precond_ref).unwrap(); + let solver = Solver::from_design(design, None, &p.params, precond_ref).unwrap(); let r = solver.solve(&p.y).unwrap(); assert!(r.converged); }); @@ -112,10 +111,9 @@ fn bench_store_backends(c: &mut Criterion) { // ArrayStore C-order: zero-copy, strided columns. group.bench_function(BenchmarkId::new("Array(C)", &p.label), |b| { b.iter(|| { - let store = - ArrayStore::new(p.categories_c.view(), ObservationWeights::Unit).unwrap(); + let store = ArrayStore::new(p.categories_c.view()).unwrap(); let design = WeightedDesign::from_store(store).unwrap(); - let solver = Solver::from_design(design, &p.params, precond_ref).unwrap(); + let solver = Solver::from_design(design, None, &p.params, precond_ref).unwrap(); let r = solver.solve(&p.y).unwrap(); assert!(r.converged); }); @@ -124,10 +122,9 @@ fn bench_store_backends(c: &mut Criterion) { // ArrayStore F-order: zero-copy, contiguous columns. 
group.bench_function(BenchmarkId::new("Array(F)", &p.label), |b| { b.iter(|| { - let store = - ArrayStore::new(p.categories_f.view(), ObservationWeights::Unit).unwrap(); + let store = ArrayStore::new(p.categories_f.view()).unwrap(); let design = WeightedDesign::from_store(store).unwrap(); - let solver = Solver::from_design(design, &p.params, precond_ref).unwrap(); + let solver = Solver::from_design(design, None, &p.params, precond_ref).unwrap(); let r = solver.solve(&p.y).unwrap(); assert!(r.converged); }); diff --git a/crates/within/examples/solve_demo.rs b/crates/within/examples/solve_demo.rs index 208b9e5..90959e4 100644 --- a/crates/within/examples/solve_demo.rs +++ b/crates/within/examples/solve_demo.rs @@ -19,12 +19,11 @@ fn main() { } // Build design to compute D * x_true. - use within::observation::{FactorMajorStore, ObservationWeights}; + use within::observation::FactorMajorStore; use within::WeightedDesign; let factor_levels = vec![categories.column(0).to_vec(), categories.column(1).to_vec()]; - let store = FactorMajorStore::new(factor_levels, ObservationWeights::Unit, n_obs) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(factor_levels, n_obs).expect("valid factor-major store"); let design = WeightedDesign::from_store(store).expect("valid design"); // True coefficient vector: x_true[j] = (j mod 7) - 3. diff --git a/crates/within/src/domain.rs b/crates/within/src/domain.rs index e0de963..42005dd 100644 --- a/crates/within/src/domain.rs +++ b/crates/within/src/domain.rs @@ -4,9 +4,11 @@ //! the linear-algebra operators ([`crate::operator`]). It answers two questions: //! //! 1. **What does the design matrix look like?** — [`WeightedDesign`] wraps an -//! [`ObservationStore`] with per-factor metadata ([`FactorMeta`]) and optional -//! observation weights, then provides the core matrix-vector products -//! (`D·x`, `D^T·r`, `D^T·W·r`) needed by every solver path. +//! 
[`ObservationStore`] with per-factor metadata ([`FactorMeta`]) and +//! provides the core matrix-vector products (`D·x`, `D^T·r`, `D^T·W·r`) +//! needed by every solver path. Observation weights are passed in as +//! `Option<&[f64]>` rather than owned, keeping the design itself purely +//! structural. //! //! 2. **How is the problem decomposed into subdomains?** — The `factor_pairs` //! submodule builds one [`Subdomain`] per connected component of each factor @@ -75,11 +77,10 @@ impl std::fmt::Debug for Subdomain { // Weighted design matrix // =========================================================================== -// Weighted design matrix: `WeightedDesign` generic over `ObservationStore`. -// -// Stores per-factor metadata via `FactorMeta` and delegates observation -// data access to the pluggable store backend `S`. Provides design matrix -// operations (D·x, D^T·r, D^T·W·r, gramian diagonal) as methods. +// `WeightedDesign` is generic over `ObservationStore`. It stores per-factor +// metadata via `FactorMeta` and delegates observation data access to the +// pluggable store backend `S`. Observation weights are not owned here — they +// are passed in as `Option<&[f64]>` to the weight-sensitive matvec methods. use std::sync::atomic::Ordering; use portable_atomic::AtomicF64; @@ -90,10 +91,12 @@ use crate::{WithinError, WithinResult}; /// Weighted fixed-effects design matrix, generic over observation storage. /// -/// `store` holds per-observation data (levels, weights); `factors` holds -/// per-factor metadata (n_levels, offset). +/// `store` holds per-observation factor levels; `factors` holds per-factor +/// metadata (n_levels, offset). Observation weights are **not** stored — they +/// are passed in to methods that need them via `Option<&[f64]>` (where `None` +/// denotes unit weights). pub struct WeightedDesign { - /// Observation storage backend (owns or borrows the raw data). + /// Observation storage backend (owns or borrows the raw factor levels). 
pub store: S, /// Per-factor metadata: level count and global DOF offset. pub factors: Vec, @@ -152,13 +155,14 @@ impl WeightedDesign { }) } - /// Weight for an observation, respecting the store's weighting mode. + /// Weight for observation `uid` given an optional weight slice. + /// + /// `None` ⇒ 1.0 (unit weights); `Some(w)` ⇒ `w[uid]`. #[inline] - pub fn uid_weight(&self, uid: usize) -> f64 { - if self.store.is_unweighted() { - 1.0 - } else { - self.store.weight(uid) + pub fn uid_weight(weights: Option<&[f64]>, uid: usize) -> f64 { + match weights { + None => 1.0, + Some(w) => w[uid], } } @@ -182,7 +186,7 @@ impl WeightedDesign { } // --------------------------------------------------------------------------- -// Design matrix operations (D·x, D^T·r, D^T·W·r, gramian diagonal) +// Design matrix operations (D·x, D^T·r, D^T·W·r) // --------------------------------------------------------------------------- /// Minimum number of rows before scatter/gather loops are parallelized. @@ -402,30 +406,18 @@ impl WeightedDesign { /// x = D^T·W·r (weighted scatter-add) /// - /// For unweighted stores, this is identical to `rmatvec_dt`. - /// The branch is outside the inner loop. - pub fn rmatvec_wdt(&self, r: &[f64], x: &mut [f64]) { + /// `weights = None` falls through to [`Self::rmatvec_dt`]. Otherwise the + /// per-row factor `w[i] * r[i]` is applied inside the scatter loop. The + /// branch lives outside the inner loop. + pub fn rmatvec_wdt(&self, weights: Option<&[f64]>, r: &[f64], x: &mut [f64]) { debug_assert_eq!(r.len(), self.n_rows); debug_assert_eq!(x.len(), self.n_dofs); - if self.store.is_unweighted() { + let Some(w) = weights else { return self.rmatvec_dt(r, x); - } + }; + debug_assert_eq!(w.len(), self.n_rows); x.fill(0.0); - self.scatter_add(x, |i| self.store.weight(i) * r[i]); - } - - /// Diagonal of D^T·W·D (weighted level counts). - /// - /// Entry `offset_q + j` = sum of weights for observations with level `j` in factor `q`. 
- pub fn gramian_diagonal(&self) -> Vec { - let mut diag = vec![0.0f64; self.n_dofs]; - let n_obs = self.store.n_obs(); - for (q, f) in self.factors.iter().enumerate() { - for uid in 0..n_obs { - diag[f.offset + self.store.level(uid, q) as usize] += self.uid_weight(uid); - } - } - diag + self.scatter_add(x, |i| w[i] * r[i]); } } @@ -436,19 +428,18 @@ impl WeightedDesign { #[cfg(test)] mod tests { use super::*; - use crate::observation::{FactorMajorStore, ObservationWeights}; + use crate::observation::FactorMajorStore; fn make_test_design() -> WeightedDesign { let categories = vec![vec![0, 1, 2, 0, 1], vec![0, 1, 2, 3, 0]]; - let store = FactorMajorStore::new(categories, ObservationWeights::Unit, 5) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(categories, 5).expect("valid factor-major store"); WeightedDesign::from_store(store).expect("valid test design") } - fn make_weighted_design(weights: Vec) -> WeightedDesign { + fn make_test_design_2x2() -> WeightedDesign { let categories = vec![vec![0, 1, 0, 1], vec![0, 0, 1, 1]]; - let store = FactorMajorStore::new(categories, ObservationWeights::Dense(weights), 4) - .expect("valid weighted factor-major store"); + let store = + FactorMajorStore::new(categories, 4).expect("valid weighted factor-major store"); WeightedDesign::from_store(store).expect("valid weighted design") } @@ -474,7 +465,6 @@ mod tests { let dm = make_test_design(); assert_eq!(dm.factors[0].n_levels, 3); assert_eq!(dm.factors[1].n_levels, 4); - // Categories are in the store, not in FactorMeta assert_eq!(dm.store.level(0, 0), 0); assert_eq!(dm.store.level(1, 0), 1); assert_eq!(dm.store.level(2, 0), 2); @@ -510,36 +500,19 @@ mod tests { let mut x_dt = vec![0.0; 7]; let mut x_wdt = vec![0.0; 7]; dm.rmatvec_dt(&r, &mut x_dt); - dm.rmatvec_wdt(&r, &mut x_wdt); + dm.rmatvec_wdt(None, &r, &mut x_wdt); assert_eq!(x_dt, x_wdt); } #[test] fn test_rmatvec_wdt_weighted() { - let dm = make_weighted_design(vec![1.0, 2.0, 3.0, 4.0]); + let 
dm = make_test_design_2x2(); + let weights = [1.0, 2.0, 3.0, 4.0]; let r = vec![1.0, 1.0, 1.0, 1.0]; let mut x = vec![0.0; 4]; - dm.rmatvec_wdt(&r, &mut x); + dm.rmatvec_wdt(Some(&weights), &r, &mut x); // factor 0: level 0 has obs 0(w=1)+2(w=3)=4, level 1 has obs 1(w=2)+3(w=4)=6 // factor 1: level 0 has obs 0(w=1)+1(w=2)=3, level 1 has obs 2(w=3)+3(w=4)=7 assert_eq!(x, vec![4.0, 6.0, 3.0, 7.0]); } - - #[test] - fn test_gramian_diagonal_unweighted() { - let dm = make_test_design(); - let diag = dm.gramian_diagonal(); - // factor 0: levels [0,1,2,0,1] -> counts [2,2,1] - // factor 1: levels [0,1,2,3,0] -> counts [2,1,1,1] - assert_eq!(diag, vec![2.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0]); - } - - #[test] - fn test_gramian_diagonal_weighted() { - let dm = make_weighted_design(vec![1.0, 2.0, 3.0, 4.0]); - let diag = dm.gramian_diagonal(); - // factor 0: level 0 -> w=1+3=4, level 1 -> w=2+4=6 - // factor 1: level 0 -> w=1+2=3, level 1 -> w=3+4=7 - assert_eq!(diag, vec![4.0, 6.0, 3.0, 7.0]); - } } diff --git a/crates/within/src/domain/factor_pairs.rs b/crates/within/src/domain/factor_pairs.rs index faca19f..d2d9e7f 100644 --- a/crates/within/src/domain/factor_pairs.rs +++ b/crates/within/src/domain/factor_pairs.rs @@ -56,6 +56,7 @@ use crate::operator::gramian::{find_all_active_levels, BipartiteComponent, Cross /// collect. 
pub(crate) fn build_local_domains( design: &WeightedDesign, + weights: Option<&[f64]>, ) -> Vec<(Subdomain, CrossTab)> { use rayon::prelude::*; @@ -65,7 +66,7 @@ pub(crate) fn build_local_domains( let mut domain_pairs: Vec<(Subdomain, CrossTab)> = pairs .par_iter() - .flat_map(|&(q, r)| domains_for_pair(design, q, r, &all_active)) + .flat_map(|&(q, r)| domains_for_pair(design, weights, q, r, &all_active)) .collect(); compute_partition_weights(&mut domain_pairs, design.n_dofs); @@ -75,14 +76,16 @@ pub(crate) fn build_local_domains( fn domains_for_pair( design: &WeightedDesign, + weights: Option<&[f64]>, q: usize, r: usize, all_active: &[Vec], ) -> Vec<(Subdomain, CrossTab)> { - let (full_ct, l2g) = match CrossTab::build_for_pair_with_active(design, q, r, all_active) { - Some(pair) => pair, - None => return Vec::new(), - }; + let (full_ct, l2g) = + match CrossTab::build_for_pair_with_active(design, weights, q, r, all_active) { + Some(pair) => pair, + None => return Vec::new(), + }; let n_q_full = full_ct.n_q(); split_into_subdomains(full_ct, &l2g, n_q_full, (q, r)) @@ -191,7 +194,7 @@ fn compute_partition_weights(domain_pairs: &mut [(Subdomain, CrossTab)], n_dofs: mod tests { use super::*; use crate::domain::WeightedDesign; - use crate::observation::{FactorMajorStore, ObservationWeights}; + use crate::observation::FactorMajorStore; fn make_test_design() -> WeightedDesign { let store = FactorMajorStore::new( @@ -200,7 +203,6 @@ mod tests { vec![0, 1, 0, 1, 0, 1], vec![0, 0, 1, 1, 0, 1], ], - ObservationWeights::Unit, 6, ) .expect("valid factor-major store"); @@ -210,7 +212,7 @@ mod tests { #[test] fn test_full_cover_domain_count() { let dm = make_test_design(); - let domain_pairs = build_local_domains(&dm); + let domain_pairs = build_local_domains(&dm, None); // 3 factor pairs; each pair may produce multiple components assert!(domain_pairs.len() >= 3); } @@ -218,7 +220,7 @@ mod tests { #[test] fn test_partition_of_unity() { let dm = make_test_design(); - let 
domain_pairs = build_local_domains(&dm); + let domain_pairs = build_local_domains(&dm, None); let n_dofs = dm.n_dofs; // Two-sided PoU: squared weights must sum to 1 at every DOF. let mut weight_sq_sum = vec![0.0; n_dofs]; @@ -238,7 +240,7 @@ mod tests { #[test] fn test_domains_cover_all_dofs() { let dm = make_test_design(); - let domain_pairs = build_local_domains(&dm); + let domain_pairs = build_local_domains(&dm, None); let mut covered = vec![false; dm.n_dofs]; for (d, _) in &domain_pairs { for &idx in d.core.global_indices() { diff --git a/crates/within/src/lib.rs b/crates/within/src/lib.rs index 0255b76..dc8c90a 100644 --- a/crates/within/src/lib.rs +++ b/crates/within/src/lib.rs @@ -191,9 +191,7 @@ pub use orchestrate::SolveResult; // --------------------------------------------------------------------------- pub use domain::{Subdomain, WeightedDesign}; -pub use observation::{ - ArrayStore, FactorMajorStore, FactorMeta, ObservationStore, ObservationWeights, -}; +pub use observation::{ArrayStore, FactorMajorStore, FactorMeta, ObservationStore}; // --------------------------------------------------------------------------- // Operators & builders diff --git a/crates/within/src/observation.rs b/crates/within/src/observation.rs index df3c3c2..c3b1ed9 100644 --- a/crates/within/src/observation.rs +++ b/crates/within/src/observation.rs @@ -1,8 +1,8 @@ //! Observation storage layer: traits, backends, and metadata. //! //! This is the lowest layer of the `within` crate. It defines *how* -//! per-observation data (factor levels, weights) is stored and accessed, -//! without knowing anything about design matrices, operators, or solvers. +//! per-observation factor-level data is stored and accessed, without knowing +//! anything about design matrices, operators, weights, or solvers. //! //! # Why pluggable backends? //! @@ -31,64 +31,21 @@ //! //! # Key types //! -//! - [`ObservationWeights`] — either unit weights (all 1.0, zero storage) or -//! 
dense per-observation weights. The `is_unit()` check is hoisted outside -//! inner loops so the hot path sees no per-element branch. //! - [`FactorMeta`] — per-factor metadata (level count and global DOF offset), //! separated from observation data so it can live in the [`WeightedDesign`](crate::domain::WeightedDesign). //! - [`ObservationStore`] — the core trait. All implementors must be //! `Send + Sync` to support Rayon parallelism in the layers above. +//! +//! # Weights +//! +//! Observation weights are intentionally **not** part of this layer. They flow +//! alongside the store as `Option<&[f64]>` (borrowed) or `Option>` +//! (owned at the solver layer), where `None` means "all weights = 1.0". use ndarray::ArrayView2; use crate::error::{WithinError, WithinResult}; -// --------------------------------------------------------------------------- -// ObservationWeights — zero-cost unweighted path -// --------------------------------------------------------------------------- - -/// Observation weights: Unit (all 1.0) or Dense (per-observation). -/// -/// The `is_unit()` check happens *outside* inner loops, so the hot path -/// sees either a constant `1.0` or a sequential array read — no per-element branch. -#[derive(Debug, Clone)] -pub enum ObservationWeights { - /// All weights = 1.0, no storage. - Unit, - /// Per-observation weights. - Dense(Vec), -} - -impl ObservationWeights { - /// Return the weight for observation `obs`. - #[inline] - pub fn get(&self, obs: usize) -> f64 { - match self { - ObservationWeights::Unit => 1.0, - ObservationWeights::Dense(w) => w[obs], - } - } - - /// Returns `true` if all weights are 1.0 (the `Unit` variant). - #[inline] - pub fn is_unit(&self) -> bool { - matches!(self, ObservationWeights::Unit) - } - - /// Validate that this weight vector is compatible with `n_obs` observations. 
- pub fn validate_for(&self, n_obs: usize) -> WithinResult<()> { - if let ObservationWeights::Dense(w) = self { - if w.len() != n_obs { - return Err(WithinError::WeightCountMismatch { - expected: n_obs, - got: w.len(), - }); - } - } - Ok(()) - } -} - // --------------------------------------------------------------------------- // FactorMeta — per-factor metadata (no observation data) // --------------------------------------------------------------------------- @@ -122,12 +79,6 @@ pub trait ObservationStore: Send + Sync { /// Level index for observation `obs` in factor `factor`. fn level(&self, obs: usize, factor: usize) -> u32; - /// Weight for observation `obs`. - fn weight(&self, obs: usize) -> f64; - - /// Whether all weights are 1.0 (enables optimized unweighted code paths). - fn is_unweighted(&self) -> bool; - /// Optional fast-path access to a factor-major column of levels. /// /// Stores that naturally keep `level(obs, factor)` as contiguous @@ -151,17 +102,12 @@ pub trait ObservationStore: Send + Sync { #[derive(Debug, Clone)] pub struct FactorMajorStore { factor_levels: Vec>, - weights: ObservationWeights, n_obs: usize, } impl FactorMajorStore { /// Create a new factor-major store, validating that all columns have length `n_obs`. 
- pub fn new( - factor_levels: Vec>, - weights: ObservationWeights, - n_obs: usize, - ) -> WithinResult { + pub fn new(factor_levels: Vec>, n_obs: usize) -> WithinResult { for (factor, col) in factor_levels.iter().enumerate() { if col.len() != n_obs { return Err(WithinError::ObservationCountMismatch { @@ -171,10 +117,8 @@ impl FactorMajorStore { }); } } - weights.validate_for(n_obs)?; Ok(Self { factor_levels, - weights, n_obs, }) } @@ -202,16 +146,6 @@ impl ObservationStore for FactorMajorStore { self.factor_levels[factor][obs] } - #[inline] - fn weight(&self, obs: usize) -> f64 { - self.weights.get(obs) - } - - #[inline] - fn is_unweighted(&self) -> bool { - self.weights.is_unit() - } - #[inline] fn factor_column(&self, factor: usize) -> Option<&[u32]> { Some(self.factor_column(factor)) @@ -235,17 +169,12 @@ impl ObservationStore for FactorMajorStore { #[derive(Debug)] pub struct ArrayStore<'a> { categories: ArrayView2<'a, u32>, - weights: ObservationWeights, } impl<'a> ArrayStore<'a> { - /// Create a zero-copy store from a borrowed 2-D category array and optional weights. - pub fn new(categories: ArrayView2<'a, u32>, weights: ObservationWeights) -> WithinResult { - weights.validate_for(categories.nrows())?; - Ok(Self { - categories, - weights, - }) + /// Create a zero-copy store from a borrowed 2-D category array. + pub fn new(categories: ArrayView2<'a, u32>) -> WithinResult { + Ok(Self { categories }) } } @@ -265,16 +194,6 @@ impl ObservationStore for ArrayStore<'_> { self.categories[[obs, factor]] } - #[inline] - fn weight(&self, obs: usize) -> f64 { - self.weights.get(obs) - } - - #[inline] - fn is_unweighted(&self) -> bool { - self.weights.is_unit() - } - fn factor_column(&self, factor: usize) -> Option<&[u32]> { let strides = self.categories.strides(); // Columns are contiguous only when the row stride is 1 (F-order). 
@@ -290,6 +209,26 @@ impl ObservationStore for ArrayStore<'_> { } } +// --------------------------------------------------------------------------- +// Weight validation helpers +// --------------------------------------------------------------------------- + +/// Validate that an optional weight slice matches `n_obs` observations. +/// +/// `None` is always valid (interpreted as unit weights). `Some(w)` requires +/// `w.len() == n_obs`. +pub(crate) fn validate_weights(weights: Option<&[f64]>, n_obs: usize) -> WithinResult<()> { + if let Some(w) = weights { + if w.len() != n_obs { + return Err(WithinError::WeightCountMismatch { + expected: n_obs, + got: w.len(), + }); + } + } + Ok(()) +} + // --------------------------------------------------------------------------- // Test helpers // --------------------------------------------------------------------------- @@ -306,39 +245,27 @@ mod tests { #[test] fn test_factor_major_store_basic() { - let store = FactorMajorStore::new(sample_factor_levels(), ObservationWeights::Unit, 4) - .expect("valid factor-major store"); + let store = + FactorMajorStore::new(sample_factor_levels(), 4).expect("valid factor-major store"); assert_eq!(store.n_obs(), 4); assert_eq!(store.n_factors(), 2); assert_eq!(store.level(0, 0), 0); assert_eq!(store.level(1, 0), 1); assert_eq!(store.level(2, 1), 0); - assert_eq!(store.weight(0), 1.0); - assert!(store.is_unweighted()); - } - - #[test] - fn test_factor_major_store_weighted() { - let store = FactorMajorStore::new( - vec![vec![0u32, 1, 2]], - ObservationWeights::Dense(vec![0.5, 1.0, 2.0]), - 3, - ) - .expect("valid weighted factor-major store"); - assert!(!store.is_unweighted()); - assert_eq!(store.weight(0), 0.5); - assert_eq!(store.weight(2), 2.0); } #[test] fn test_factor_column() { - let store = FactorMajorStore::new( - vec![vec![0u32, 1, 2, 0], vec![3, 2, 1, 0]], - ObservationWeights::Unit, - 4, - ) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(vec![vec![0u32, 
1, 2, 0], vec![3, 2, 1, 0]], 4) + .expect("valid factor-major store"); assert_eq!(store.factor_column(0), &[0u32, 1, 2, 0]); assert_eq!(store.factor_column(1), &[3u32, 2, 1, 0]); } + + #[test] + fn test_validate_weights() { + assert!(validate_weights(None, 5).is_ok()); + assert!(validate_weights(Some(&[1.0, 2.0, 3.0, 4.0, 5.0]), 5).is_ok()); + assert!(validate_weights(Some(&[1.0, 2.0]), 5).is_err()); + } } diff --git a/crates/within/src/operator.rs b/crates/within/src/operator.rs index 69ef810..cc83ba4 100644 --- a/crates/within/src/operator.rs +++ b/crates/within/src/operator.rs @@ -66,17 +66,12 @@ pub struct WeightedDesignOperator<'a, S: ObservationStore> { } impl<'a, S: ObservationStore> WeightedDesignOperator<'a, S> { - /// Create from a weighted design matrix. - pub fn new(design: &'a WeightedDesign) -> Self { - let sqrt_weights = if design.store.is_unweighted() { - None - } else { - Some( - (0..design.n_rows) - .map(|i| design.uid_weight(i).sqrt()) - .collect(), - ) - }; + /// Create from a weighted design matrix and optional observation weights. + /// + /// `weights = None` selects the unweighted fast-path: `sqrt_weights` is + /// `None` and `apply` / `apply_adjoint` skip the per-row scaling entirely. 
+ pub fn new(design: &'a WeightedDesign, weights: Option<&[f64]>) -> Self { + let sqrt_weights = weights.map(|w| w.iter().map(|wi| wi.sqrt()).collect::>()); Self { scratch: Mutex::new(vec![0.0; design.n_rows]), design, diff --git a/crates/within/src/operator/gramian/cross_tab.rs b/crates/within/src/operator/gramian/cross_tab.rs index 2dd828a..253b8d1 100644 --- a/crates/within/src/operator/gramian/cross_tab.rs +++ b/crates/within/src/operator/gramian/cross_tab.rs @@ -225,12 +225,14 @@ impl CrossTab { #[cfg(test)] pub fn build_for_pair( design: &WeightedDesign, + weights: Option<&[f64]>, q: usize, r: usize, ) -> Option<(Self, Vec)> { let active = find_active_levels(design, q, r)?; - let (c, diag_q, diag_r) = accumulate_cross_block(design, q, r, &active.as_compact_pair()); + let (c, diag_q, diag_r) = + accumulate_cross_block(design, weights, q, r, &active.as_compact_pair()); let ct = c.transpose(); let cross_tab = CrossTab { c, @@ -247,6 +249,7 @@ impl CrossTab { /// active levels have already been determined via `find_all_active_levels`. pub fn build_for_pair_with_active( design: &WeightedDesign, + weights: Option<&[f64]>, q: usize, r: usize, all_active: &[Vec], @@ -255,7 +258,8 @@ impl CrossTab { let fr = &design.factors[r]; let active = build_compact_mapping(&all_active[q], &all_active[r], fq, fr)?; - let (c, diag_q, diag_r) = accumulate_cross_block(design, q, r, &active.as_compact_pair()); + let (c, diag_q, diag_r) = + accumulate_cross_block(design, weights, q, r, &active.as_compact_pair()); let ct = c.transpose(); let cross_tab = CrossTab { c, @@ -406,21 +410,23 @@ impl CrossTab { /// Dispatches to a dense or sparse path based on the table size. 
fn accumulate_cross_block( design: &WeightedDesign, + weights: Option<&[f64]>, q: usize, r: usize, compact: &CompactPair<'_>, ) -> (CsrBlock, Vec, Vec) { let table_size = compact.n_q * compact.n_r; if table_size <= DENSE_TABLE_MAX_ENTRIES { - accumulate_dense_cross_block(design, q, r, compact) + accumulate_dense_cross_block(design, weights, q, r, compact) } else { - accumulate_sparse_cross_block(design, q, r, compact) + accumulate_sparse_cross_block(design, weights, q, r, compact) } } /// Dense path: flat table with O(1) accumulation per observation (n_q * n_r <= 5M). fn accumulate_dense_cross_block( design: &WeightedDesign, + weights: Option<&[f64]>, q: usize, r: usize, compact: &CompactPair<'_>, @@ -442,7 +448,7 @@ fn accumulate_dense_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = design.uid_weight(uid); + let w = WeightedDesign::::uid_weight(weights, uid); debug_assert!((cj as usize) < n_q && (ck as usize) < n_r); diag_q[cj as usize] += w; diag_r[ck as usize] += w; @@ -460,6 +466,7 @@ fn accumulate_dense_cross_block( /// row. The workspace sort is on unique columns only (n_r_active << len). 
fn accumulate_sparse_cross_block( design: &WeightedDesign, + weights: Option<&[f64]>, q: usize, r: usize, compact: &CompactPair<'_>, @@ -482,7 +489,7 @@ fn accumulate_sparse_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = design.uid_weight(uid); + let w = WeightedDesign::::uid_weight(weights, uid); diag_q[cj as usize] += w; diag_r[ck as usize] += w; row_counts[cj as usize] += 1; @@ -507,7 +514,7 @@ fn accumulate_sparse_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = design.uid_weight(uid); + let w = WeightedDesign::::uid_weight(weights, uid); let pos = cursor[cj as usize] as usize; bucket_cols[pos] = ck; bucket_vals[pos] = w; diff --git a/crates/within/src/operator/gramian/tests.rs b/crates/within/src/operator/gramian/tests.rs index 7f05b86..598a387 100644 --- a/crates/within/src/operator/gramian/tests.rs +++ b/crates/within/src/operator/gramian/tests.rs @@ -7,7 +7,7 @@ use proptest::prelude::*; use super::CrossTab; use crate::domain::WeightedDesign; -use crate::observation::{FactorMajorStore, ObservationWeights}; +use crate::observation::FactorMajorStore; use crate::operator::gramian::find_all_active_levels; #[test] @@ -26,29 +26,21 @@ fn test_cross_tab_sparse_accumulation_path() { } // Sparse path (large level counts) - let store_sparse = FactorMajorStore::new( - vec![fa.clone(), fb.clone()], - ObservationWeights::Unit, - n_obs, - ) - .expect("valid sparse store"); + let store_sparse = + FactorMajorStore::new(vec![fa.clone(), fb.clone()], n_obs).expect("valid sparse store"); let design_sparse = WeightedDesign::from_store(store_sparse).expect("valid sparse design"); - let (ct_sparse, _) = - CrossTab::build_for_pair(&design_sparse, 0, 1).expect("sparse cross tab should build"); + let (ct_sparse, _) = CrossTab::build_for_pair(&design_sparse, None, 0, 1) + .expect("sparse cross tab should build"); // Dense path reference: collapse levels to a small range so n_q * n_r <= 5M. 
// Map each observation to level % 100 for both factors (100*100 = 10 000 <= 5M). let fa_small: Vec = fa.iter().map(|&x| x % 100).collect(); let fb_small: Vec = fb.iter().map(|&x| x % 100).collect(); - let store_dense = FactorMajorStore::new( - vec![fa_small.clone(), fb_small.clone()], - ObservationWeights::Unit, - n_obs, - ) - .expect("valid dense store"); + let store_dense = FactorMajorStore::new(vec![fa_small.clone(), fb_small.clone()], n_obs) + .expect("valid dense store"); let design_dense = WeightedDesign::from_store(store_dense).expect("valid dense design"); let (ct_dense, _) = - CrossTab::build_for_pair(&design_dense, 0, 1).expect("dense cross tab should build"); + CrossTab::build_for_pair(&design_dense, None, 0, 1).expect("dense cross tab should build"); // The sparse CrossTab for the large design should have identical diagonals // to what we'd compute by hand (each observation appears exactly once in the @@ -124,10 +116,9 @@ fn test_extract_component_two_components() { let fa = vec![0u32, 0, 1, 1, 2, 2, 3, 3]; let fb = vec![0u32, 1, 0, 1, 2, 3, 2, 3]; let n_obs = 8; - let store = - FactorMajorStore::new(vec![fa, fb], ObservationWeights::Unit, n_obs).expect("valid store"); + let store = FactorMajorStore::new(vec![fa, fb], n_obs).expect("valid store"); let design = WeightedDesign::from_store(store).expect("valid design"); - let (ct, _) = CrossTab::build_for_pair(&design, 0, 1).expect("cross tab should build"); + let (ct, _) = CrossTab::build_for_pair(&design, None, 0, 1).expect("cross tab should build"); let components = ct.bipartite_connected_components(); assert_eq!(components.len(), 2, "should have 2 connected components"); @@ -235,13 +226,9 @@ proptest! 
{ fb.push((s % n_r as u64) as u32); } - let store = FactorMajorStore::new( - vec![fa, fb], - ObservationWeights::Unit, - n_obs, - ).expect("valid store"); + let store = FactorMajorStore::new(vec![fa, fb], n_obs).expect("valid store"); let design = WeightedDesign::from_store(store).expect("valid design"); - let (ct, _) = CrossTab::build_for_pair(&design, 0, 1) + let (ct, _) = CrossTab::build_for_pair(&design, None, 0, 1) .expect("cross tab should build"); let components = ct.bipartite_connected_components(); @@ -295,8 +282,7 @@ fn test_find_all_active_levels_with_gaps() { let fa = vec![0u32, 2, 4, 0, 2, 4]; let fb = vec![0u32, 1, 2, 0, 1, 2]; let n_obs = 6; - let store = - FactorMajorStore::new(vec![fa, fb], ObservationWeights::Unit, n_obs).expect("valid store"); + let store = FactorMajorStore::new(vec![fa, fb], n_obs).expect("valid store"); let design = WeightedDesign::from_store(store).expect("valid design"); let active = find_all_active_levels(&design); diff --git a/crates/within/src/operator/preconditioner.rs b/crates/within/src/operator/preconditioner.rs index 0f1cb4e..bf0b9cf 100644 --- a/crates/within/src/operator/preconditioner.rs +++ b/crates/within/src/operator/preconditioner.rs @@ -93,16 +93,18 @@ impl Operator for FePreconditioner { } } -/// Build a [`FePreconditioner`] from a design and configuration. +/// Build a [`FePreconditioner`] from a design, optional observation weights, +/// and configuration. 
pub fn build_preconditioner( design: &WeightedDesign, + weights: Option<&[f64]>, config: &Preconditioner, ) -> WithinResult { use crate::domain::build_local_domains; match config { Preconditioner::Additive(local, reduction) => { - let domains = build_local_domains(design); + let domains = build_local_domains(design, weights); let p = build_additive_with_strategy(domains, design.n_dofs, local, *reduction)?; Ok(FePreconditioner::Additive(p)) } diff --git a/crates/within/src/operator/tests.rs b/crates/within/src/operator/tests.rs index 45d3a1b..0af205e 100644 --- a/crates/within/src/operator/tests.rs +++ b/crates/within/src/operator/tests.rs @@ -4,24 +4,20 @@ mod design_tests { use crate::domain::WeightedDesign; - use crate::observation::{FactorMajorStore, ObservationWeights}; + use crate::observation::FactorMajorStore; use crate::operator::WeightedDesignOperator; use schwarz_precond::Operator; fn make_test_design() -> WeightedDesign { - let store = FactorMajorStore::new( - vec![vec![0, 1, 2, 0, 1], vec![0, 1, 2, 3, 0]], - ObservationWeights::Unit, - 5, - ) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(vec![vec![0, 1, 2, 0, 1], vec![0, 1, 2, 3, 0]], 5) + .expect("valid factor-major store"); WeightedDesign::from_store(store).expect("valid test design") } #[test] fn test_design_operator_dimensions() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema); + let op = WeightedDesignOperator::new(&schema, None); assert_eq!(op.nrows(), 5); assert_eq!(op.ncols(), 7); } @@ -29,7 +25,7 @@ mod design_tests { #[test] fn test_design_operator_adjoint() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema); + let op = WeightedDesignOperator::new(&schema, None); let x = vec![1.0, -0.5, 2.0, 0.3, -1.0, 0.7, 1.5]; let r = vec![0.1, 0.2, -0.3, 0.4, -0.5]; @@ -48,7 +44,7 @@ mod design_tests { #[test] fn test_matvec_d() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema); + 
let op = WeightedDesignOperator::new(&schema, None); let x = vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0]; let mut y = vec![0.0; 5]; op.apply(&x, &mut y).expect("apply succeeds"); @@ -58,7 +54,7 @@ mod design_tests { #[test] fn test_rmatvec_dt() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema); + let op = WeightedDesignOperator::new(&schema, None); let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let mut x = vec![0.0; 7]; op.apply_adjoint(&r, &mut x) @@ -430,7 +426,7 @@ mod schwarz_tests { ApproxCholConfig, ApproxSchurConfig, LocalSolverConfig, DEFAULT_DENSE_SCHUR_THRESHOLD, }; use crate::domain::{build_local_domains, Subdomain, SubdomainCore, WeightedDesign}; - use crate::observation::{FactorMajorStore, ObservationWeights}; + use crate::observation::FactorMajorStore; use crate::operator::csr_block::CsrBlock; use crate::operator::gramian::CrossTab; use crate::operator::local_solver::BlockElimSolver; @@ -442,14 +438,10 @@ mod schwarz_tests { const BLOCK_ELIM_NESTED_RAYON_CHILD_ENV: &str = "WITHIN_TEST_BLOCK_ELIM_NESTED_RAYON_CHILD"; fn make_test_data() -> (WeightedDesign, Vec<(Subdomain, CrossTab)>) { - let store = FactorMajorStore::new( - vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]], - ObservationWeights::Unit, - 5, - ) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]], 5) + .expect("valid factor-major store"); let design = WeightedDesign::from_store(store).expect("valid fixed-effects design"); - let domain_pairs = build_local_domains(&design); + let domain_pairs = build_local_domains(&design, None); (design, domain_pairs) } diff --git a/crates/within/src/solver.rs b/crates/within/src/solver.rs index 8c50fe1..ff86bb8 100644 --- a/crates/within/src/solver.rs +++ b/crates/within/src/solver.rs @@ -42,7 +42,7 @@ use crate::operator::WeightedDesignOperator; use crate::config::{Preconditioner, SolverParams}; use crate::domain::WeightedDesign; -use 
crate::observation::{ArrayStore, ObservationStore, ObservationWeights}; +use crate::observation::{validate_weights, ArrayStore, ObservationStore}; use crate::operator::preconditioner::{build_preconditioner, FePreconditioner}; use crate::orchestrate::{BatchSolveResult, SolveResult}; use crate::WithinResult; @@ -59,6 +59,7 @@ fn norm(v: &[f64]) -> f64 { /// construction time. pub struct Solver { design: WeightedDesign, + weights: Option>, preconditioner: Option, tol: f64, maxiter: usize, @@ -66,19 +67,25 @@ pub struct Solver { } impl Solver { - /// Build from an existing [`WeightedDesign`]. + /// Build from an existing [`WeightedDesign`] and optional observation weights. + /// + /// `weights = None` denotes unit weights (length must match `design.n_rows` + /// when `Some`). pub fn from_design( design: WeightedDesign, + weights: Option>, params: &SolverParams, preconditioner: Option<&Preconditioner>, ) -> WithinResult { + validate_weights(weights.as_deref(), design.n_rows)?; let built_precond = match preconditioner { - Some(config) => Some(build_preconditioner(&design, config)?), + Some(config) => Some(build_preconditioner(&design, weights.as_deref(), config)?), None => None, }; Ok(Self { design, + weights, preconditioner: built_precond, tol: params.tol, maxiter: params.maxiter, @@ -89,11 +96,14 @@ impl Solver { /// Build from a design with a pre-built preconditioner (e.g. deserialized). 
pub fn from_design_with_preconditioner( design: WeightedDesign, + weights: Option>, params: &SolverParams, preconditioner: FePreconditioner, ) -> WithinResult { + validate_weights(weights.as_deref(), design.n_rows)?; Ok(Self { design, + weights, preconditioner: Some(preconditioner), tol: params.tol, maxiter: params.maxiter, @@ -106,7 +116,7 @@ impl Solver { let t_start = Instant::now(); let t_setup_start = Instant::now(); - let rect_op = WeightedDesignOperator::new(&self.design); + let rect_op = WeightedDesignOperator::new(&self.design, self.weights.as_deref()); let b = rect_op.weighted_rhs(y); let t_solve_start = Instant::now(); @@ -125,11 +135,12 @@ impl Solver { *d = yi - *d; } + let w_ref = self.weights.as_deref(); let mut rhs = vec![0.0; self.design.n_dofs]; - self.design.rmatvec_wdt(y, &mut rhs); + self.design.rmatvec_wdt(w_ref, y, &mut rhs); let rhs_norm = norm(&rhs).max(1e-15); let mut residual_dof = vec![0.0; self.design.n_dofs]; - self.design.rmatvec_wdt(&demeaned, &mut residual_dof); + self.design.rmatvec_wdt(w_ref, &demeaned, &mut residual_dof); let final_residual = norm(&residual_dof) / rhs_norm; Ok(SolveResult { @@ -205,13 +216,10 @@ impl<'a> Solver> { params: &SolverParams, preconditioner: Option<&Preconditioner>, ) -> WithinResult { - let weights = match weights { - Some(w) => ObservationWeights::Dense(w.to_vec()), - None => ObservationWeights::Unit, - }; - let store = ArrayStore::new(categories, weights)?; + let store = ArrayStore::new(categories)?; let design = WeightedDesign::from_store(store)?; - Self::from_design(design, params, preconditioner) + let weights = weights.map(|w| w.to_vec()); + Self::from_design(design, weights, params, preconditioner) } /// Build a solver with a pre-built preconditioner (e.g. deserialized). 
@@ -221,12 +229,9 @@ impl<'a> Solver> { params: &SolverParams, preconditioner: FePreconditioner, ) -> WithinResult { - let weights = match weights { - Some(w) => ObservationWeights::Dense(w.to_vec()), - None => ObservationWeights::Unit, - }; - let store = ArrayStore::new(categories, weights)?; + let store = ArrayStore::new(categories)?; let design = WeightedDesign::from_store(store)?; - Self::from_design_with_preconditioner(design, params, preconditioner) + let weights = weights.map(|w| w.to_vec()); + Self::from_design_with_preconditioner(design, weights, params, preconditioner) } } diff --git a/crates/within/tests/array_store.rs b/crates/within/tests/array_store.rs index a4d1cdb..53bf3ab 100644 --- a/crates/within/tests/array_store.rs +++ b/crates/within/tests/array_store.rs @@ -1,5 +1,5 @@ use ndarray::{array, Array2, ShapeBuilder}; -use within::observation::{ArrayStore, FactorMajorStore, ObservationStore, ObservationWeights}; +use within::observation::{ArrayStore, FactorMajorStore, ObservationStore}; use within::{solve, Preconditioner, SolverParams, WeightedDesign}; #[path = "common/orchestrate_helpers.rs"] @@ -54,11 +54,11 @@ fn test_array_store_f_contiguous_matches_factor_major() { let factor_cols: Vec> = (0..2) .map(|f| cats.column(f).iter().copied().collect()) .collect(); - let store = FactorMajorStore::new(factor_cols, ObservationWeights::Unit, cats.nrows()) - .expect("valid FactorMajorStore"); + let store = FactorMajorStore::new(factor_cols, cats.nrows()).expect("valid FactorMajorStore"); let design = WeightedDesign::from_store(store).expect("valid design"); - let solver = within::Solver::from_design(design, &default_params(), Some(&additive_precond())) - .expect("solver"); + let solver = + within::Solver::from_design(design, None, &default_params(), Some(&additive_precond())) + .expect("solver"); let result_fms = solver.solve(&y).expect("FactorMajorStore solve"); assert!(result_array.converged); @@ -101,7 +101,7 @@ fn 
test_array_store_factor_column_f_order() { f.assign(&cats); f }; - let store = ArrayStore::new(cats_f.view(), ObservationWeights::Unit).expect("valid store"); + let store = ArrayStore::new(cats_f.view()).expect("valid store"); assert!(store.factor_column(0).is_some()); assert!(store.factor_column(1).is_some()); } @@ -111,7 +111,7 @@ fn test_array_store_factor_column_c_order() { // C-contiguous array should return None from factor_column() let cats = array![[0u32, 0], [1, 0], [0, 1], [1, 1]]; assert!(cats.is_standard_layout()); // C-contiguous - let store = ArrayStore::new(cats.view(), ObservationWeights::Unit).expect("valid store"); + let store = ArrayStore::new(cats.view()).expect("valid store"); assert!(store.factor_column(0).is_none()); assert!(store.factor_column(1).is_none()); } diff --git a/crates/within/tests/common/orchestrate_helpers.rs b/crates/within/tests/common/orchestrate_helpers.rs index ffe9165..0758e61 100644 --- a/crates/within/tests/common/orchestrate_helpers.rs +++ b/crates/within/tests/common/orchestrate_helpers.rs @@ -1,21 +1,16 @@ #![allow(dead_code)] -use within::{FactorMajorStore, ObservationWeights, SolveResult, WeightedDesign}; +use within::{FactorMajorStore, SolveResult, WeightedDesign}; pub fn make_test_design() -> WeightedDesign { - make_weighted_design( - vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]], - ObservationWeights::Unit, - ) - .expect("valid test design") + make_design(vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]]).expect("valid test design") } -pub fn make_weighted_design( +pub fn make_design( categories: Vec>, - weights: ObservationWeights, ) -> within::WithinResult> { let n_rows = categories.first().map_or(0, Vec::len); - let store = FactorMajorStore::new(categories, weights, n_rows)?; + let store = FactorMajorStore::new(categories, n_rows)?; WeightedDesign::from_store(store) } diff --git a/crates/within/tests/domain.rs b/crates/within/tests/domain.rs index 1c9f624..912d7a0 100644 --- a/crates/within/tests/domain.rs +++ 
b/crates/within/tests/domain.rs @@ -1,10 +1,9 @@ //! Integration tests for the domain layer: WeightedDesign operations, -//! adjoint properties, gramian diagonal identity, and convergence through -//! the solve API for designs that exercise partition-of-unity weights and -//! disconnected bipartite structure. +//! adjoint properties, and convergence through the solve API for designs that +//! exercise partition-of-unity weights and disconnected bipartite structure. use proptest::prelude::*; -use within::observation::{FactorMajorStore, ObservationWeights}; +use within::observation::FactorMajorStore; use within::WeightedDesign; // --------------------------------------------------------------------------- @@ -12,8 +11,7 @@ use within::WeightedDesign; // --------------------------------------------------------------------------- fn make_design(categories: Vec>, n_obs: usize) -> WeightedDesign { - let store = FactorMajorStore::new(categories, ObservationWeights::Unit, n_obs) - .expect("valid factor-major store"); + let store = FactorMajorStore::new(categories, n_obs).expect("valid factor-major store"); WeightedDesign::from_store(store).expect("valid design") } @@ -106,96 +104,15 @@ fn test_large_design_rmatvec_dt_correctness() { } } -#[test] -fn test_large_design_gramian_diagonal_from_unit_vectors() { - // Verify: gramian_diagonal()[j] == e_j^T · (D^T · D) · e_j - // which equals ||D · e_j||^2 for unweighted designs. - let dm = make_large_design(); - let n_dofs = dm.n_dofs; - let n_rows = dm.n_rows; - - let diag = dm.gramian_diagonal(); - - for j in 0..n_dofs { - let mut ej = vec![0.0f64; n_dofs]; - ej[j] = 1.0; - - let mut dej = vec![0.0f64; n_rows]; - dm.matvec_d(&ej, &mut dej); - - let norm_sq: f64 = dej.iter().map(|v| v * v).sum(); - assert!( - (diag[j] - norm_sq).abs() < 1e-10, - "gramian_diagonal()[{j}]={} != ||D·e_j||^2={norm_sq}", - diag[j] - ); - } -} - // --------------------------------------------------------------------------- -// 2. 
Gramian diagonal algebraic identity (property test) +// 2. Weighted adjoint property (proptest) // --------------------------------------------------------------------------- proptest! { #![proptest_config(ProptestConfig::with_cases(10))] - /// For random weights and factor structures, gramian_diagonal()[j] must equal - /// e_j^T · D^T · W · D · e_j for all DOFs j. - #[test] - fn prop_gramian_diagonal_matches_unit_vector_quadratic_form( - n_obs in 20usize..=200, - n_levels_a in 2usize..=15, - n_levels_b in 2usize..=15, - seed in 0u64..1000, - ) { - // Build deterministic-ish factor arrays from seed + indices. - let fa: Vec = (0..n_obs) - .map(|i| ((i * 3 + seed as usize * 7) % n_levels_a) as u32) - .collect(); - let fb: Vec = (0..n_obs) - .map(|i| ((i * 5 + seed as usize * 11) % n_levels_b) as u32) - .collect(); - - // Weights: non-uniform, vary by observation index. - let weights: Vec = (0..n_obs) - .map(|i| 0.5 + (i as f64 * 0.1 + seed as f64 * 0.3).sin().abs()) - .collect(); - - let store = FactorMajorStore::new( - vec![fa, fb], - ObservationWeights::Dense(weights), - n_obs, - ) - .unwrap(); - let dm = WeightedDesign::from_store(store).unwrap(); - - let n_dofs = dm.n_dofs; - let n_rows = dm.n_rows; - let diag = dm.gramian_diagonal(); - - for j in 0..n_dofs { - let mut ej = vec![0.0f64; n_dofs]; - ej[j] = 1.0; - - // D · e_j - let mut dej = vec![0.0f64; n_rows]; - dm.matvec_d(&ej, &mut dej); - - // e_j^T · D^T · W · D · e_j = sum_i w_i * (D·e_j)[i]^2 - let quadratic: f64 = (0..n_rows) - .map(|i| dm.uid_weight(i) * dej[i] * dej[i]) - .sum(); - - prop_assert!( - (diag[j] - quadratic).abs() < 1e-10, - "gramian_diagonal()[{j}]={} != e_j^T·D^T·W·D·e_j={quadratic}", - diag[j] - ); - } - } - /// The adjoint property must hold for random designs: - /// == (i.e., == ) + /// == (i.e., == ) #[test] fn prop_weighted_adjoint_property( n_obs in 20usize..=200, @@ -214,12 +131,7 @@ proptest! 
{ .map(|i| 0.5 + (i as f64 * 0.13 + seed as f64 * 0.41).sin().abs()) .collect(); - let store = FactorMajorStore::new( - vec![fa, fb], - ObservationWeights::Dense(weights.clone()), - n_obs, - ) - .unwrap(); + let store = FactorMajorStore::new(vec![fa, fb], n_obs).unwrap(); let dm = WeightedDesign::from_store(store).unwrap(); let n_dofs = dm.n_dofs; @@ -244,7 +156,7 @@ proptest! { // let mut wdtr = vec![0.0f64; n_dofs]; - dm.rmatvec_wdt(&r, &mut wdtr); + dm.rmatvec_wdt(Some(&weights), &r, &mut wdtr); let rhs: f64 = x.iter().zip(wdtr.iter()).map(|(xi, wi)| xi * wi).sum(); prop_assert!( @@ -275,8 +187,7 @@ fn test_three_factor_design_solve_converges() { let fb: Vec = (0..n_obs).map(|i| ((i / n_lev) % n_lev) as u32).collect(); let fc: Vec = (0..n_obs).map(|i| ((i * 3) % n_lev) as u32).collect(); - let store = FactorMajorStore::new(vec![fa, fb, fc], ObservationWeights::Unit, n_obs) - .expect("valid 3-factor store"); + let store = FactorMajorStore::new(vec![fa, fb, fc], n_obs).expect("valid 3-factor store"); let dm = WeightedDesign::from_store(store).expect("valid 3-factor design"); assert_eq!(dm.n_factors(), 3); @@ -339,12 +250,8 @@ fn test_disconnected_design_larger_converges() { cats[[i, 1]] = fb[i]; } - let store = FactorMajorStore::new( - vec![fa.clone(), fb.clone()], - ObservationWeights::Unit, - n_obs, - ) - .expect("valid disconnected store"); + let store = FactorMajorStore::new(vec![fa.clone(), fb.clone()], n_obs) + .expect("valid disconnected store"); let dm = WeightedDesign::from_store(store).expect("valid disconnected design"); let x_true = vec![1.0f64; dm.n_dofs]; @@ -407,8 +314,7 @@ fn test_disconnected_design_solve_converges() { #[test] fn test_single_factor_design_construction() { let categories = vec![vec![0u32, 1, 2, 0, 1]]; - let store = - FactorMajorStore::new(categories, ObservationWeights::Unit, 5).expect("valid store"); + let store = FactorMajorStore::new(categories, 5).expect("valid store"); let dm = WeightedDesign::from_store(store).expect("valid 
single-factor design"); assert_eq!(dm.n_factors(), 1, "expected 1 factor"); @@ -419,8 +325,7 @@ fn test_single_factor_design_construction() { #[test] fn test_single_factor_design_adjoint_property() { let categories = vec![vec![0u32, 1, 2, 0, 1]]; - let store = - FactorMajorStore::new(categories, ObservationWeights::Unit, 5).expect("valid store"); + let store = FactorMajorStore::new(categories, 5).expect("valid store"); let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); let n_dofs = dm.n_dofs; @@ -444,19 +349,6 @@ fn test_single_factor_design_adjoint_property() { ); } -#[test] -fn test_single_factor_design_gramian_diagonal_is_level_counts() { - // For unweighted single-factor: gramian_diagonal = observation count per level. - // levels: [0, 1, 2, 0, 1] → counts [2, 2, 1] - let categories = vec![vec![0u32, 1, 2, 0, 1]]; - let store = - FactorMajorStore::new(categories, ObservationWeights::Unit, 5).expect("valid store"); - let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); - - let diag = dm.gramian_diagonal(); - assert_eq!(diag, vec![2.0, 2.0, 1.0]); -} - /// A single-factor design has no factor pairs, so the additive Schwarz /// preconditioner has no subdomains to work with. 
The solver should still /// function (falling back to unpreconditioned LSMR) or be able to solve the @@ -495,8 +387,7 @@ fn test_single_factor_design_solve_without_precond() { fn test_single_factor_matvec_d_values() { // D·[a, b, c] with levels [0,1,2,0,1] should give [a, b, c, a, b] let categories = vec![vec![0u32, 1, 2, 0, 1]]; - let store = - FactorMajorStore::new(categories, ObservationWeights::Unit, 5).expect("valid store"); + let store = FactorMajorStore::new(categories, 5).expect("valid store"); let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); let x = vec![10.0, 20.0, 30.0]; @@ -509,8 +400,7 @@ fn test_single_factor_matvec_d_values() { fn test_single_factor_rmatvec_dt_values() { // D^T·[1,2,3,4,5] with levels [0,1,2,0,1] should give [1+4, 2+5, 3] = [5, 7, 3] let categories = vec![vec![0u32, 1, 2, 0, 1]]; - let store = - FactorMajorStore::new(categories, ObservationWeights::Unit, 5).expect("valid store"); + let store = FactorMajorStore::new(categories, 5).expect("valid store"); let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; diff --git a/crates/within/tests/edge_cases.rs b/crates/within/tests/edge_cases.rs index 685f721..9a4fe68 100644 --- a/crates/within/tests/edge_cases.rs +++ b/crates/within/tests/edge_cases.rs @@ -1,7 +1,7 @@ use ndarray::array; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; -use within::observation::{FactorMajorStore, ObservationWeights}; +use within::observation::FactorMajorStore; use within::{solve, Preconditioner, Solver, SolverParams, WeightedDesign}; #[path = "common/orchestrate_helpers.rs"] @@ -137,7 +137,7 @@ fn test_maxiter_1_partial_result() { (0..n_obs).map(|_| rng.random_range(0..20u32)).collect(), (0..n_obs).map(|_| rng.random_range(0..20u32)).collect(), ]; - let store = FactorMajorStore::new(cats, ObservationWeights::Unit, n_obs).expect("valid store"); + let store = FactorMajorStore::new(cats, n_obs).expect("valid 
store"); let design = WeightedDesign::from_store(store).expect("valid design"); let y: Vec = (0..n_obs).map(|i| (i as f64 * 0.17).sin()).collect(); @@ -147,7 +147,7 @@ fn test_maxiter_1_partial_result() { maxiter: 1, ..SolverParams::default() }; - let solver = Solver::from_design(design, &params, None).expect("solver build"); + let solver = Solver::from_design(design, None, &params, None).expect("solver build"); let result = solver.solve(&y).expect("solve with maxiter=1"); // Convergence is not expected (tolerance is unreachable in 1 iteration), @@ -184,8 +184,7 @@ fn test_large_design_convergence() { (0..n_obs).map(|_| rng.random_range(0..100u32)).collect(), ]; - let store = - FactorMajorStore::new(cats, ObservationWeights::Unit, n_obs).expect("valid large store"); + let store = FactorMajorStore::new(cats, n_obs).expect("valid large store"); let design = WeightedDesign::from_store(store).expect("valid large design"); let y = common::make_y_from_unit_solution(&design); @@ -194,7 +193,7 @@ fn test_large_design_convergence() { ..SolverParams::default() }; let precond = additive_precond(); - let solver = Solver::from_design(design, &params, Some(&precond)).expect("solver build"); + let solver = Solver::from_design(design, None, &params, Some(&precond)).expect("solver build"); let result = solver.solve(&y).expect("large design solve"); assert!( @@ -217,7 +216,7 @@ fn test_zero_rhs_zero_solution() { let y = vec![0.0f64; design.n_rows]; let params = SolverParams::default(); - let solver = Solver::from_design(design, &params, None).expect("solver build"); + let solver = Solver::from_design(design, None, &params, None).expect("solver build"); let result = solver.solve(&y).expect("zero RHS solve"); assert!(result.converged, "zero RHS should trivially converge"); @@ -278,7 +277,7 @@ fn test_repeated_solve_is_deterministic() { let params = SolverParams::default(); let precond = additive_precond(); - let solver = Solver::from_design(design, &params, Some(&precond)).expect("solver build"); + let solver = 
Solver::from_design(design, None, &params, Some(&precond)).expect("solver build"); let r1 = solver.solve(&y).expect("first solve"); let r2 = solver.solve(&y).expect("second solve"); diff --git a/crates/within/tests/error_paths.rs b/crates/within/tests/error_paths.rs index bf4e81e..d52801e 100644 --- a/crates/within/tests/error_paths.rs +++ b/crates/within/tests/error_paths.rs @@ -2,14 +2,13 @@ use std::error::Error; use ndarray::Array2; use schwarz_precond::{PreconditionerBuildError, SolveError}; -use within::observation::{FactorMajorStore, ObservationWeights}; -use within::{solve, Preconditioner, SolverParams, WeightedDesign, WithinError}; +use within::observation::FactorMajorStore; +use within::{solve, Preconditioner, Solver, SolverParams, WeightedDesign, WithinError}; #[test] fn test_empty_observations_error() { // FactorMajorStore::new allows 0 rows; EmptyObservations is raised by WeightedDesign::from_store - let store = - FactorMajorStore::new(vec![vec![], vec![]], ObservationWeights::Unit, 0).expect("store ok"); + let store = FactorMajorStore::new(vec![vec![], vec![]], 0).expect("store ok"); let result = WeightedDesign::from_store(store); assert!(result.is_err()); match result.unwrap_err() { @@ -21,8 +20,7 @@ fn test_empty_observations_error() { #[test] fn test_observation_count_mismatch_error() { // Factor columns have different lengths - let result = - FactorMajorStore::new(vec![vec![0, 1, 2], vec![0, 1]], ObservationWeights::Unit, 3); + let result = FactorMajorStore::new(vec![vec![0, 1, 2], vec![0, 1]], 3); assert!(result.is_err()); match result.unwrap_err() { WithinError::ObservationCountMismatch { .. 
} => {} @@ -32,13 +30,15 @@ fn test_observation_count_mismatch_error() { #[test] fn test_weight_count_mismatch_error() { - let result = FactorMajorStore::new( - vec![vec![0, 1, 2], vec![0, 1, 0]], - ObservationWeights::Dense(vec![1.0, 2.0]), // wrong length - 3, - ); - assert!(result.is_err()); - match result.unwrap_err() { + // Weights of wrong length are caught at Solver construction time. + let store = FactorMajorStore::new(vec![vec![0, 1, 2], vec![0, 1, 0]], 3).expect("store ok"); + let design = WeightedDesign::from_store(store).expect("valid design"); + let params = SolverParams::default(); + let result = Solver::from_design(design, Some(vec![1.0, 2.0]), &params, None); + let err = result + .err() + .expect("expected WeightCountMismatch error, got Ok"); + match err { WithinError::WeightCountMismatch { .. } => {} other => panic!("Expected WeightCountMismatch, got: {:?}", other), } diff --git a/crates/within/tests/orchestrate_solve.rs b/crates/within/tests/orchestrate_solve.rs index 4d4ee06..83bb39b 100644 --- a/crates/within/tests/orchestrate_solve.rs +++ b/crates/within/tests/orchestrate_solve.rs @@ -13,7 +13,7 @@ fn test_lsmr_unpreconditioned() { maxiter: 1000, ..Default::default() }; - let solver = Solver::from_design(design, &params, None).expect("build solver"); + let solver = Solver::from_design(design, None, &params, None).expect("build solver"); let result = solver.solve(&y).expect("solve"); common::assert_converged_with_small_residual(&result, 1e-6); } @@ -29,7 +29,7 @@ fn test_lsmr_preconditioned() { ..Default::default() }; let precond = Preconditioner::Additive(LocalSolverConfig::default(), ReductionStrategy::Auto); - let solver = Solver::from_design(design, &params, Some(&precond)).expect("build solver"); + let solver = Solver::from_design(design, None, &params, Some(&precond)).expect("build solver"); let result = solver.solve(&y).expect("solve"); common::assert_converged_with_small_residual(&result, 1e-6); } @@ -44,7 +44,7 @@ fn test_lsmr_least_squares() { maxiter: 
1000, ..Default::default() }; - let solver = Solver::from_design(design, &params, None).expect("build solver"); + let solver = Solver::from_design(design, None, &params, None).expect("build solver"); let result = solver.solve(&y).expect("solve"); assert!(result.converged, "LSMR LS did not converge"); common::assert_solution_finite(&result); @@ -52,11 +52,9 @@ fn test_lsmr_least_squares() { #[test] fn test_lsmr_least_squares_weighted_preconditioned() { - let design = common::make_weighted_design( - vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]], - within::ObservationWeights::Dense(vec![1.0, 2.0, 1.5, 0.5, 3.0]), - ) - .expect("valid weighted design"); + let design = + common::make_design(vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]]).expect("valid design"); + let weights = vec![1.0, 2.0, 1.5, 0.5, 3.0]; let y = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let params = SolverParams { @@ -65,7 +63,8 @@ fn test_lsmr_least_squares_weighted_preconditioned() { ..Default::default() }; let precond = Preconditioner::default(); - let solver = Solver::from_design(design, &params, Some(&precond)).expect("build solver"); + let solver = + Solver::from_design(design, Some(weights), &params, Some(&precond)).expect("build solver"); let result = solver.solve(&y).expect("solve"); common::assert_converged_with_small_residual(&result, 1e-6); common::assert_solution_finite(&result); @@ -73,11 +72,9 @@ fn test_lsmr_least_squares_weighted_preconditioned() { #[test] fn test_lsmr_weighted() { - let design = common::make_weighted_design( - vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]], - within::ObservationWeights::Dense(vec![1.0, 2.0, 1.5, 0.5, 3.0]), - ) - .expect("valid weighted design"); + let design = + common::make_design(vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]]).expect("valid design"); + let weights = vec![1.0, 2.0, 1.5, 0.5, 3.0]; let y = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let params = SolverParams { @@ -86,7 +83,8 @@ fn test_lsmr_weighted() { ..Default::default() }; let precond = Preconditioner::default(); - 
let solver = Solver::from_design(design, &params, Some(&precond)).expect("build solver"); + let solver = + Solver::from_design(design, Some(weights), &params, Some(&precond)).expect("build solver"); let result = solver.solve(&y).expect("solve"); common::assert_converged_with_small_residual(&result, 1e-6); common::assert_solution_finite(&result); diff --git a/crates/within/tests/properties.rs b/crates/within/tests/properties.rs index 593d2a7..26f0c83 100644 --- a/crates/within/tests/properties.rs +++ b/crates/within/tests/properties.rs @@ -1,7 +1,7 @@ use ndarray::Array2; use proptest::prelude::*; use schwarz_precond::Operator; -use within::observation::{ArrayStore, ObservationWeights}; +use within::observation::ArrayStore; use within::{solve, FePreconditioner, Preconditioner, SolverParams, WeightedDesign}; /// Generate a random fixed-effects problem as (categories Array2, y Vec). @@ -72,7 +72,7 @@ proptest! { #[test] fn prop_solver_convergence((cats, _y) in random_fe_problem_strategy()) { // Create y = D * x_true so we know the answer - let store = ArrayStore::new(cats.view(), ObservationWeights::Unit).unwrap(); + let store = ArrayStore::new(cats.view()).unwrap(); let design = WeightedDesign::from_store(store).unwrap(); let n_dofs = design.n_dofs; let n_obs = design.n_rows; diff --git a/crates/within/tests/property_gaps.rs b/crates/within/tests/property_gaps.rs index 3a90b84..3087da4 100644 --- a/crates/within/tests/property_gaps.rs +++ b/crates/within/tests/property_gaps.rs @@ -1,6 +1,6 @@ use ndarray::Array2; use proptest::prelude::*; -use within::observation::{ArrayStore, ObservationWeights}; +use within::observation::ArrayStore; use within::{solve, Preconditioner, Solver, SolverParams, WeightedDesign}; #[path = "common/orchestrate_helpers.rs"] @@ -104,7 +104,7 @@ proptest! 
{ }; let precond = additive_precond(); // Build the design with a unit-solution RHS so the problem is feasible - let store = ArrayStore::new(cats.view(), ObservationWeights::Unit).unwrap(); + let store = ArrayStore::new(cats.view()).unwrap(); let design = WeightedDesign::from_store(store).unwrap(); let y_feasible: Vec = { let x_true = vec![1.0; design.n_dofs]; @@ -207,7 +207,7 @@ proptest! { #[test] fn prop_single_factor_converges((cats, _y) in single_factor_strategy()) { // Build a consistent RHS: y = D * 1 so the system is exactly solvable. - let store = ArrayStore::new(cats.view(), ObservationWeights::Unit).unwrap(); + let store = ArrayStore::new(cats.view()).unwrap(); let design = WeightedDesign::from_store(store).unwrap(); let n_levels = design.n_dofs; let x_true = vec![1.0; n_levels]; diff --git a/crates/within/tests/solver.rs b/crates/within/tests/solver.rs index 2232e45..63be313 100644 --- a/crates/within/tests/solver.rs +++ b/crates/within/tests/solver.rs @@ -149,7 +149,7 @@ fn test_solver_from_design() { let params = default_params(); let precond = additive_precond(); - let solver = Solver::from_design(design, &params, Some(&precond)).expect("from_design"); + let solver = Solver::from_design(design, None, &params, Some(&precond)).expect("from_design"); let result = solver.solve(&y).expect("solve"); assert!(result.converged); } From c7a305e29d666a0ea3b251a2a675e3e4b72760d2 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 11 May 2026 11:11:26 +0200 Subject: [PATCH 2/6] =?UTF-8?q?rename=20ObservationStore=E2=86=92Store,=20?= =?UTF-8?q?WeightedDesign=E2=86=92Design,=20WeightedDesignOperator?= =?UTF-8?q?=E2=86=92DesignOperator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mechanical identifier rename across the workspace plus README and module docstrings. No semantic change; every call site, trait bound, and lib.rs re-export is updated to the new names. 
--- crates/within-py/src/lib.rs | 4 +- crates/within/README.md | 9 +++-- crates/within/benches/fixest.rs | 17 ++++----- crates/within/benches/store_backend.rs | 8 ++-- crates/within/examples/solve_demo.rs | 4 +- crates/within/src/domain.rs | 36 +++++++++--------- crates/within/src/domain/factor_pairs.rs | 18 ++++----- crates/within/src/lib.rs | 14 +++---- crates/within/src/observation.rs | 16 ++++---- crates/within/src/operator.rs | 18 ++++----- crates/within/src/operator/gramian.rs | 2 +- .../within/src/operator/gramian/cross_tab.rs | 38 +++++++++---------- crates/within/src/operator/gramian/tests.rs | 12 +++--- crates/within/src/operator/preconditioner.rs | 8 ++-- crates/within/src/operator/schwarz.rs | 2 +- crates/within/src/operator/tests.rs | 22 +++++------ crates/within/src/orchestrate.rs | 2 +- crates/within/src/solver.rs | 24 ++++++------ crates/within/tests/array_store.rs | 6 +-- .../tests/common/orchestrate_helpers.rs | 12 +++--- crates/within/tests/domain.rs | 24 ++++++------ crates/within/tests/edge_cases.rs | 6 +-- crates/within/tests/error_paths.rs | 8 ++-- crates/within/tests/properties.rs | 4 +- crates/within/tests/property_gaps.rs | 8 ++-- 25 files changed, 157 insertions(+), 165 deletions(-) diff --git a/crates/within-py/src/lib.rs b/crates/within-py/src/lib.rs index 5476b44..14f0ca3 100644 --- a/crates/within-py/src/lib.rs +++ b/crates/within-py/src/lib.rs @@ -43,7 +43,7 @@ use within::config::{ ApproxCholConfig, ApproxSchurConfig, LocalSolverConfig, Preconditioner, ReductionStrategy, SolverParams, DEFAULT_DENSE_SCHUR_THRESHOLD, }; -use within::domain::WeightedDesign; +use within::domain::Design; use within::observation::FactorMajorStore; use within::{ solve as solve_native, solve_batch as solve_batch_native, FePreconditioner, Operator, @@ -701,7 +701,7 @@ impl PySolver { .map(|f| cats.column(f).iter().copied().collect()) .collect(); let store = FactorMajorStore::new(factor_levels, n_obs).map_err(value_err)?; - let design = 
WeightedDesign::from_store(store).map_err(value_err)?; + let design = Design::from_store(store).map_err(value_err)?; let weights_vec: Option> = weights .as_ref() .map(|w| w.as_array().iter().copied().collect()); diff --git a/crates/within/README.md b/crates/within/README.md index fa4b2c3..2e8ca65 100644 --- a/crates/within/README.md +++ b/crates/within/README.md @@ -51,14 +51,15 @@ assert!(result.converged); The crate is organized in four layers: -1. **`observation`** — Per-observation factor levels and weights via - `FactorMajorStore` and the `ObservationStore` trait. +1. **`observation`** — Per-observation factor levels via `FactorMajorStore` + and the `Store` trait. Observation weights are not owned here — they + flow as `Option<&[f64]>` to the operator layer. -2. **`domain`** — Domain decomposition. `WeightedDesign` wraps a store with +2. **`domain`** — Domain decomposition. `Design` wraps a store with factor metadata; `build_local_domains` constructs factor-pair subdomains with partition-of-unity weights for the Schwarz preconditioner. -3. **`operator`** — Linear algebra primitives. `WeightedDesignOperator` +3. **`operator`** — Linear algebra primitives. `DesignOperator` (rectangular `sqrt(W) D` for LSMR) and Schwarz preconditioner builders that wire approximate Cholesky local solvers into the generic `schwarz-precond` framework. 
diff --git a/crates/within/benches/fixest.rs b/crates/within/benches/fixest.rs index 036617c..30403f7 100644 --- a/crates/within/benches/fixest.rs +++ b/crates/within/benches/fixest.rs @@ -10,9 +10,9 @@ use schwarz_precond::Operator; use within::config::{ ApproxCholConfig, LocalSolverConfig, Preconditioner, ReductionStrategy, SolverParams, }; -use within::domain::WeightedDesign; +use within::domain::Design; use within::observation::FactorMajorStore; -use within::operator::WeightedDesignOperator; +use within::operator::DesignOperator; use within::Solver; // =========================================================================== @@ -45,10 +45,7 @@ impl Case { } } -fn generate_fixest_like_case( - case: Case, - seed: u64, -) -> (WeightedDesign, Vec) { +fn generate_fixest_like_case(case: Case, seed: u64) -> (Design, Vec) { let mut rng = SmallRng::seed_from_u64(seed); let n_years = 10usize; let n_indiv_per_firm = 23usize; @@ -77,7 +74,7 @@ fn generate_fixest_like_case( }; let store = FactorMajorStore::new(factor_levels, case.n_obs).expect("valid factor-major store"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let mut x_true = vec![0.0; design.n_dofs]; for x in &mut x_true { @@ -122,7 +119,7 @@ fn configure_group<'a>( fn run_smoke( group: &mut BenchmarkGroup<'_, WallTime>, label: &str, - design: &WeightedDesign, + design: &Design, y: &[f64], ) { group.bench_function(BenchmarkId::new(label, ""), |b| { @@ -130,7 +127,7 @@ fn run_smoke( }); } -fn run_lsmr_one_level(design: &WeightedDesign, y: &[f64], ac2: bool) { +fn run_lsmr_one_level(design: &Design, y: &[f64], ac2: bool) { let params = SolverParams { tol: TOL, maxiter: MAXITER, @@ -280,7 +277,7 @@ fn bench_matvec(c: &mut Criterion) { let (design, _y) = generate_fixest_like_case(case, 42); let n_dofs = design.n_dofs; let n_obs = design.n_rows; - let op = WeightedDesignOperator::new(&design, None); + let op = 
DesignOperator::new(&design, None); let x: Vec = (0..n_dofs).map(|i| (i as f64).sin()).collect(); let mut y = vec![0.0; n_obs]; group.bench_function(BenchmarkId::new("apply", &label), |b| { diff --git a/crates/within/benches/store_backend.rs b/crates/within/benches/store_backend.rs index 572f44c..f802582 100644 --- a/crates/within/benches/store_backend.rs +++ b/crates/within/benches/store_backend.rs @@ -8,7 +8,7 @@ use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; use within::config::{Preconditioner, SolverParams}; -use within::domain::WeightedDesign; +use within::domain::Design; use within::observation::{ArrayStore, FactorMajorStore}; use within::Solver; @@ -101,7 +101,7 @@ fn bench_store_backends(c: &mut Criterion) { .map(|q| p.categories_c.column(q).to_vec()) .collect(); let store = FactorMajorStore::new(factor_levels, *n_obs).unwrap(); - let design = WeightedDesign::from_store(store).unwrap(); + let design = Design::from_store(store).unwrap(); let solver = Solver::from_design(design, None, &p.params, precond_ref).unwrap(); let r = solver.solve(&p.y).unwrap(); assert!(r.converged); @@ -112,7 +112,7 @@ fn bench_store_backends(c: &mut Criterion) { group.bench_function(BenchmarkId::new("Array(C)", &p.label), |b| { b.iter(|| { let store = ArrayStore::new(p.categories_c.view()).unwrap(); - let design = WeightedDesign::from_store(store).unwrap(); + let design = Design::from_store(store).unwrap(); let solver = Solver::from_design(design, None, &p.params, precond_ref).unwrap(); let r = solver.solve(&p.y).unwrap(); assert!(r.converged); @@ -123,7 +123,7 @@ fn bench_store_backends(c: &mut Criterion) { group.bench_function(BenchmarkId::new("Array(F)", &p.label), |b| { b.iter(|| { let store = ArrayStore::new(p.categories_f.view()).unwrap(); - let design = WeightedDesign::from_store(store).unwrap(); + let design = Design::from_store(store).unwrap(); let solver = Solver::from_design(design, None, &p.params, precond_ref).unwrap(); let r = solver.solve(&p.y).unwrap(); 
assert!(r.converged); diff --git a/crates/within/examples/solve_demo.rs b/crates/within/examples/solve_demo.rs index 90959e4..75f82e5 100644 --- a/crates/within/examples/solve_demo.rs +++ b/crates/within/examples/solve_demo.rs @@ -20,11 +20,11 @@ fn main() { // Build design to compute D * x_true. use within::observation::FactorMajorStore; - use within::WeightedDesign; + use within::Design; let factor_levels = vec![categories.column(0).to_vec(), categories.column(1).to_vec()]; let store = FactorMajorStore::new(factor_levels, n_obs).expect("valid factor-major store"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); // True coefficient vector: x_true[j] = (j mod 7) - 3. let total_dofs = design.n_dofs; diff --git a/crates/within/src/domain.rs b/crates/within/src/domain.rs index 42005dd..3c64a25 100644 --- a/crates/within/src/domain.rs +++ b/crates/within/src/domain.rs @@ -3,8 +3,8 @@ //! This module sits between raw observation storage ([`crate::observation`]) and //! the linear-algebra operators ([`crate::operator`]). It answers two questions: //! -//! 1. **What does the design matrix look like?** — [`WeightedDesign`] wraps an -//! [`ObservationStore`] with per-factor metadata ([`FactorMeta`]) and +//! 1. **What does the design matrix look like?** — [`Design`] wraps an +//! [`Store`] with per-factor metadata ([`FactorMeta`]) and //! provides the core matrix-vector products (`D·x`, `D^T·r`, `D^T·W·r`) //! needed by every solver path. Observation weights are passed in as //! `Option<&[f64]>` rather than owned, keeping the design itself purely @@ -77,7 +77,7 @@ impl std::fmt::Debug for Subdomain { // Weighted design matrix // =========================================================================== -// `WeightedDesign` is generic over `ObservationStore`. It stores per-factor +// `Design` is generic over `Store`. 
It stores per-factor // metadata via `FactorMeta` and delegates observation data access to the // pluggable store backend `S`. Observation weights are not owned here — they // are passed in as `Option<&[f64]>` to the weight-sensitive matvec methods. @@ -86,7 +86,7 @@ use std::sync::atomic::Ordering; use portable_atomic::AtomicF64; use rayon::prelude::*; -use crate::observation::{FactorMeta, ObservationStore}; +use crate::observation::{FactorMeta, Store}; use crate::{WithinError, WithinResult}; /// Weighted fixed-effects design matrix, generic over observation storage. @@ -95,7 +95,7 @@ use crate::{WithinError, WithinResult}; /// metadata (n_levels, offset). Observation weights are **not** stored — they /// are passed in to methods that need them via `Option<&[f64]>` (where `None` /// denotes unit weights). -pub struct WeightedDesign { +pub struct Design { /// Observation storage backend (owns or borrows the raw factor levels). pub store: S, /// Per-factor metadata: level count and global DOF offset. @@ -106,7 +106,7 @@ pub struct WeightedDesign { pub n_dofs: usize, } -impl Clone for WeightedDesign { +impl Clone for Design { fn clone(&self) -> Self { Self { store: self.store.clone(), @@ -117,9 +117,9 @@ impl Clone for WeightedDesign { } } -impl std::fmt::Debug for WeightedDesign { +impl std::fmt::Debug for Design { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("WeightedDesign") + f.debug_struct("Design") .field("store", &self.store) .field("factors", &self.factors) .field("n_rows", &self.n_rows) @@ -128,7 +128,7 @@ impl std::fmt::Debug for WeightedDesign WeightedDesign { +impl Design { /// Construct from a store, inferring the number of levels per factor /// from the maximum observed level in each column (`max + 1`). 
pub fn from_store(store: S) -> WithinResult { @@ -147,7 +147,7 @@ impl WeightedDesign { offset += n_levels; } let n_rows = store.n_obs(); - Ok(WeightedDesign { + Ok(Design { store, factors, n_rows, @@ -211,7 +211,7 @@ enum ScatterStrategy { } #[inline] -fn level_from_column_or_store( +fn level_from_column_or_store( store: &S, levels: Option<&[u32]>, row: usize, @@ -232,7 +232,7 @@ fn level_from_column_or_store( /// All branches share the same per-row `(level, value)` computation via /// `level_value`; only the accumulation strategy differs. #[allow(clippy::too_many_arguments)] -fn scatter_add_single_factor( +fn scatter_add_single_factor( slice: &mut [f64], n_rows: usize, n_levels: usize, @@ -244,7 +244,7 @@ fn scatter_add_single_factor( atomic_buf: &mut Vec, ) { #[inline(always)] - fn level_value( + fn level_value( store: &S, levels: Option<&[u32]>, q: usize, @@ -302,7 +302,7 @@ fn scatter_add_single_factor( } } -impl WeightedDesign { +impl Design { /// Gather-add: `dst[i] += src[offset_q + level(i, q)]` for each factor `q` and row `i`. /// /// This is the core loop of `y = D·x`. 
Loop order is chosen based on the @@ -430,17 +430,17 @@ mod tests { use super::*; use crate::observation::FactorMajorStore; - fn make_test_design() -> WeightedDesign { + fn make_test_design() -> Design { let categories = vec![vec![0, 1, 2, 0, 1], vec![0, 1, 2, 3, 0]]; let store = FactorMajorStore::new(categories, 5).expect("valid factor-major store"); - WeightedDesign::from_store(store).expect("valid test design") + Design::from_store(store).expect("valid test design") } - fn make_test_design_2x2() -> WeightedDesign { + fn make_test_design_2x2() -> Design { let categories = vec![vec![0, 1, 0, 1], vec![0, 0, 1, 1]]; let store = FactorMajorStore::new(categories, 4).expect("valid weighted factor-major store"); - WeightedDesign::from_store(store).expect("valid weighted design") + Design::from_store(store).expect("valid weighted design") } #[test] diff --git a/crates/within/src/domain/factor_pairs.rs b/crates/within/src/domain/factor_pairs.rs index d2d9e7f..589bb85 100644 --- a/crates/within/src/domain/factor_pairs.rs +++ b/crates/within/src/domain/factor_pairs.rs @@ -40,8 +40,8 @@ //! //! [`build_local_domains`] builds subdomains for preconditioner construction. -use super::{PartitionWeights, Subdomain, WeightedDesign}; -use crate::observation::ObservationStore; +use super::{Design, PartitionWeights, Subdomain}; +use crate::observation::Store; use crate::operator::gramian::{find_all_active_levels, BipartiteComponent, CrossTab}; /// Build local subdomains (with pre-built CrossTabs) for pairs of factors. @@ -54,8 +54,8 @@ use crate::operator::gramian::{find_all_active_levels, BipartiteComponent, Cross /// Factor pairs are processed in parallel via Rayon. The /// `compute_partition_weights` step remains sequential after the parallel /// collect. 
-pub(crate) fn build_local_domains( - design: &WeightedDesign, +pub(crate) fn build_local_domains( + design: &Design, weights: Option<&[f64]>, ) -> Vec<(Subdomain, CrossTab)> { use rayon::prelude::*; @@ -74,8 +74,8 @@ pub(crate) fn build_local_domains( domain_pairs } -fn domains_for_pair( - design: &WeightedDesign, +fn domains_for_pair( + design: &Design, weights: Option<&[f64]>, q: usize, r: usize, @@ -193,10 +193,10 @@ fn compute_partition_weights(domain_pairs: &mut [(Subdomain, CrossTab)], n_dofs: #[cfg(test)] mod tests { use super::*; - use crate::domain::WeightedDesign; + use crate::domain::Design; use crate::observation::FactorMajorStore; - fn make_test_design() -> WeightedDesign { + fn make_test_design() -> Design { let store = FactorMajorStore::new( vec![ vec![0, 1, 2, 0, 1, 2], @@ -206,7 +206,7 @@ mod tests { 6, ) .expect("valid factor-major store"); - WeightedDesign::from_store(store).expect("valid test design") + Design::from_store(store).expect("valid test design") } #[test] diff --git a/crates/within/src/lib.rs b/crates/within/src/lib.rs index dc8c90a..22147b7 100644 --- a/crates/within/src/lib.rs +++ b/crates/within/src/lib.rs @@ -94,7 +94,7 @@ //! ```text //! within //! ├── observation Storage backends (FactorMajorStore, ArrayStore) -//! ├── domain WeightedDesign + factor-pair subdomains +//! ├── domain Design + factor-pair subdomains //! │ └── factor_pairs Domain construction, partition-of-unity weights //! ├── operator Linear algebra layer //! │ ├── gramian CrossTab (per-pair bipartite blocks) @@ -129,7 +129,7 @@ //! `operator::schur_complement` module implements block-elimination local //! solves on the per-pair `CrossTab` (`operator::gramian::cross_tab`). //! -//! - **Extending with new backends** — The [`ObservationStore`] trait in +//! - **Extending with new backends** — The [`Store`] trait in //! [`observation`] abstracts over how factor-level data is laid out in //! memory. The [`schwarz_precond::LocalSolver`] trait (from the //! 
`schwarz-precond` crate) governs subdomain solvers. @@ -144,10 +144,10 @@ //! The crate is organized in four layers: //! //! - **`observation`** — Per-observation data storage via [`FactorMajorStore`] -//! and the [`ObservationStore`] trait. -//! - **`domain`** — Domain decomposition: [`WeightedDesign`] wraps a store with factor +//! and the [`Store`] trait. +//! - **`domain`** — Domain decomposition: [`Design`] wraps a store with factor //! metadata; factor-pair subdomains are built with partition-of-unity weights. -//! - **`operator`** — Linear algebra primitives: `WeightedDesignOperator` +//! - **`operator`** — Linear algebra primitives: `DesignOperator` //! (rectangular `sqrt(W) D` for LSMR), per-pair `CrossTab` blocks, and //! Schwarz preconditioner builders. //! - **`orchestrate`** — End-to-end solve: [`solve`] with typed configuration. @@ -190,8 +190,8 @@ pub use orchestrate::SolveResult; // Core types // --------------------------------------------------------------------------- -pub use domain::{Subdomain, WeightedDesign}; -pub use observation::{ArrayStore, FactorMajorStore, FactorMeta, ObservationStore}; +pub use domain::{Design, Subdomain}; +pub use observation::{ArrayStore, FactorMajorStore, FactorMeta, Store}; // --------------------------------------------------------------------------- // Operators & builders diff --git a/crates/within/src/observation.rs b/crates/within/src/observation.rs index c3b1ed9..720b0f7 100644 --- a/crates/within/src/observation.rs +++ b/crates/within/src/observation.rs @@ -13,7 +13,7 @@ //! - **Rust tests and benchmarks** build data programmatically as `Vec>`, //! which is naturally factor-major. //! -//! The [`ObservationStore`] trait abstracts over these layouts so that all +//! The [`Store`] trait abstracts over these layouts so that all //! upstream code (design matrix operations, domain decomposition, Gramian //! assembly) is generic and layout-agnostic. //! @@ -24,7 +24,7 @@ //! 
| [`FactorMajorStore`] | `factor_levels[q][i]` — grouped by factor | Yes | Rust-native construction; sequential factor-column access for Gramian build and domain decomposition | //! | [`ArrayStore`] | `categories[[i, q]]` — borrowed `ArrayView2` | No (borrows) | Zero-copy from numpy; F-contiguous arrays get contiguous column access matching `FactorMajorStore` performance | //! -//! Both backends implement the optional [`ObservationStore::factor_column`] +//! Both backends implement the optional [`Store::factor_column`] //! fast-path, which returns a contiguous `&[u32]` slice for a factor's levels //! when the memory layout permits it. The design-matrix scatter/gather loops //! exploit this to avoid per-element virtual dispatch. @@ -32,8 +32,8 @@ //! # Key types //! //! - [`FactorMeta`] — per-factor metadata (level count and global DOF offset), -//! separated from observation data so it can live in the [`WeightedDesign`](crate::domain::WeightedDesign). -//! - [`ObservationStore`] — the core trait. All implementors must be +//! separated from observation data so it can live in the [`Design`](crate::domain::Design). +//! - [`Store`] — the core trait. All implementors must be //! `Send + Sync` to support Rayon parallelism in the layers above. //! //! # Weights @@ -62,14 +62,14 @@ pub struct FactorMeta { } // --------------------------------------------------------------------------- -// ObservationStore trait +// Store trait // --------------------------------------------------------------------------- /// Core abstraction: how observation data is stored and accessed. /// /// Each backend optimizes for different data characteristics. /// All implementors must be `Send + Sync` for Rayon parallelism. -pub trait ObservationStore: Send + Sync { +pub trait Store: Send + Sync { /// Number of observations. 
fn n_obs(&self) -> usize; @@ -130,7 +130,7 @@ impl FactorMajorStore { } } -impl ObservationStore for FactorMajorStore { +impl Store for FactorMajorStore { #[inline] fn n_obs(&self) -> usize { self.n_obs @@ -178,7 +178,7 @@ impl<'a> ArrayStore<'a> { } } -impl ObservationStore for ArrayStore<'_> { +impl Store for ArrayStore<'_> { #[inline] fn n_obs(&self) -> usize { self.categories.nrows() diff --git a/crates/within/src/operator.rs b/crates/within/src/operator.rs index cc83ba4..b5f8707 100644 --- a/crates/within/src/operator.rs +++ b/crates/within/src/operator.rs @@ -12,7 +12,7 @@ //! //! | Operator | Type | Description | //! |---|---|---| -//! | **sqrt(W) D** | [`WeightedDesignOperator`] | Rectangular operator used by LSMR. Implements `sqrt(W) D x` and `D^T sqrt(W) x` via gather/scatter on the observation store | +//! | **sqrt(W) D** | [`DesignOperator`] | Rectangular operator used by LSMR. Implements `sqrt(W) D x` and `D^T sqrt(W) x` via gather/scatter on the observation store | //! //! # Submodules //! @@ -40,15 +40,15 @@ pub(crate) mod schwarz; mod tests; // --------------------------------------------------------------------------- -// WeightedDesignOperator — rectangular, W^{1/2}·D·x / D^T·W^{1/2}·x +// DesignOperator — rectangular, W^{1/2}·D·x / D^T·W^{1/2}·x // --------------------------------------------------------------------------- use std::sync::Mutex; use schwarz_precond::Operator; -use crate::domain::WeightedDesign; -use crate::observation::ObservationStore; +use crate::domain::Design; +use crate::observation::Store; /// Weighted rectangular design operator: `A = W^{1/2} D`. /// @@ -57,20 +57,20 @@ use crate::observation::ObservationStore; /// /// The normal equations of this operator give `A^T A = D^T W D = G` (the Gramian), /// so the existing Schwarz preconditioner approximating `G^{-1}` can be used directly. 
-pub struct WeightedDesignOperator<'a, S: ObservationStore> { - design: &'a WeightedDesign, +pub struct DesignOperator<'a, S: Store> { + design: &'a Design, /// Pre-computed `sqrt(w_i)` per observation. `None` when unweighted. sqrt_weights: Option>, /// Scratch for the adjoint path: stores `sqrt(w_i) * u_i`. scratch: Mutex>, } -impl<'a, S: ObservationStore> WeightedDesignOperator<'a, S> { +impl<'a, S: Store> DesignOperator<'a, S> { /// Create from a weighted design matrix and optional observation weights. /// /// `weights = None` selects the unweighted fast-path: `sqrt_weights` is /// `None` and `apply` / `apply_adjoint` skip the per-row scaling entirely. - pub fn new(design: &'a WeightedDesign, weights: Option<&[f64]>) -> Self { + pub fn new(design: &'a Design, weights: Option<&[f64]>) -> Self { let sqrt_weights = weights.map(|w| w.iter().map(|wi| wi.sqrt()).collect::>()); Self { scratch: Mutex::new(vec![0.0; design.n_rows]), @@ -89,7 +89,7 @@ impl<'a, S: ObservationStore> WeightedDesignOperator<'a, S> { } } } -impl Operator for WeightedDesignOperator<'_, S> { +impl Operator for DesignOperator<'_, S> { fn nrows(&self) -> usize { self.design.n_rows } diff --git a/crates/within/src/operator/gramian.rs b/crates/within/src/operator/gramian.rs index fa12cf6..e97e1a4 100644 --- a/crates/within/src/operator/gramian.rs +++ b/crates/within/src/operator/gramian.rs @@ -8,7 +8,7 @@ //! //! The LSMR solver does not assemble the full Gramian `G = D^T W D`; it //! works on the rectangular operator `sqrt(W) D` directly via -//! [`crate::operator::WeightedDesignOperator`]. Only per-pair `CrossTab`s are +//! [`crate::operator::DesignOperator`]. Only per-pair `CrossTab`s are //! needed, and only for preconditioner construction. 
mod cross_tab; diff --git a/crates/within/src/operator/gramian/cross_tab.rs b/crates/within/src/operator/gramian/cross_tab.rs index 253b8d1..531ecf4 100644 --- a/crates/within/src/operator/gramian/cross_tab.rs +++ b/crates/within/src/operator/gramian/cross_tab.rs @@ -27,8 +27,8 @@ //! A `local_to_global` vector maps these back to global DOF indices. use super::super::csr_block::CsrBlock; -use crate::domain::WeightedDesign; -use crate::observation::ObservationStore; +use crate::domain::Design; +use crate::observation::Store; /// Max entries in a flat dense cross-tab accumulator (~40 MB at 8 bytes each). const DENSE_TABLE_MAX_ENTRIES: usize = 5_000_000; @@ -74,7 +74,7 @@ impl ActiveLevels { /// Scan all observations once and mark which levels are active for each factor. /// /// Returns `active[f][level]` = true if any observation uses that level of factor f. -pub fn find_all_active_levels(design: &WeightedDesign) -> Vec> { +pub fn find_all_active_levels(design: &Design) -> Vec> { let n_factors = design.factors.len(); let n_obs = design.store.n_obs(); let mut active: Vec> = design @@ -150,11 +150,7 @@ fn build_compact_mapping( /// /// Returns `None` if either factor has no active levels. #[cfg(test)] -fn find_active_levels( - design: &WeightedDesign, - q: usize, - r: usize, -) -> Option { +fn find_active_levels(design: &Design, q: usize, r: usize) -> Option { let fq = &design.factors[q]; let fr = &design.factors[r]; let n_obs = design.store.n_obs(); @@ -223,8 +219,8 @@ impl CrossTab { /// Also returns `local_to_global`: q-levels first, then r-levels, matching /// the convention used by `ActiveLevels` and `SubdomainCore::global_indices`. 
#[cfg(test)] - pub fn build_for_pair( - design: &WeightedDesign, + pub fn build_for_pair( + design: &Design, weights: Option<&[f64]>, q: usize, r: usize, @@ -247,8 +243,8 @@ impl CrossTab { /// /// Like `build_for_pair` but avoids redundant observation scans when /// active levels have already been determined via `find_all_active_levels`. - pub fn build_for_pair_with_active( - design: &WeightedDesign, + pub fn build_for_pair_with_active( + design: &Design, weights: Option<&[f64]>, q: usize, r: usize, @@ -408,8 +404,8 @@ impl CrossTab { /// `u32::MAX` are skipped. /// /// Dispatches to a dense or sparse path based on the table size. -fn accumulate_cross_block( - design: &WeightedDesign, +fn accumulate_cross_block( + design: &Design, weights: Option<&[f64]>, q: usize, r: usize, @@ -424,8 +420,8 @@ fn accumulate_cross_block( } /// Dense path: flat table with O(1) accumulation per observation (n_q * n_r <= 5M). -fn accumulate_dense_cross_block( - design: &WeightedDesign, +fn accumulate_dense_cross_block( + design: &Design, weights: Option<&[f64]>, q: usize, r: usize, @@ -448,7 +444,7 @@ fn accumulate_dense_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = WeightedDesign::::uid_weight(weights, uid); + let w = Design::::uid_weight(weights, uid); debug_assert!((cj as usize) < n_q && (ck as usize) < n_r); diag_q[cj as usize] += w; diag_r[ck as usize] += w; @@ -464,8 +460,8 @@ fn accumulate_dense_cross_block( /// Bucket observations by row in two passes (count + fill), then use /// a dense workspace of size n_r to accumulate and deduplicate each /// row. The workspace sort is on unique columns only (n_r_active << len). 
-fn accumulate_sparse_cross_block( - design: &WeightedDesign, +fn accumulate_sparse_cross_block( + design: &Design, weights: Option<&[f64]>, q: usize, r: usize, @@ -489,7 +485,7 @@ fn accumulate_sparse_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = WeightedDesign::::uid_weight(weights, uid); + let w = Design::::uid_weight(weights, uid); diag_q[cj as usize] += w; diag_r[ck as usize] += w; row_counts[cj as usize] += 1; @@ -514,7 +510,7 @@ fn accumulate_sparse_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = WeightedDesign::::uid_weight(weights, uid); + let w = Design::::uid_weight(weights, uid); let pos = cursor[cj as usize] as usize; bucket_cols[pos] = ck; bucket_vals[pos] = w; diff --git a/crates/within/src/operator/gramian/tests.rs b/crates/within/src/operator/gramian/tests.rs index 598a387..5c0bde7 100644 --- a/crates/within/src/operator/gramian/tests.rs +++ b/crates/within/src/operator/gramian/tests.rs @@ -6,7 +6,7 @@ use proptest::prelude::*; use super::CrossTab; -use crate::domain::WeightedDesign; +use crate::domain::Design; use crate::observation::FactorMajorStore; use crate::operator::gramian::find_all_active_levels; @@ -28,7 +28,7 @@ fn test_cross_tab_sparse_accumulation_path() { // Sparse path (large level counts) let store_sparse = FactorMajorStore::new(vec![fa.clone(), fb.clone()], n_obs).expect("valid sparse store"); - let design_sparse = WeightedDesign::from_store(store_sparse).expect("valid sparse design"); + let design_sparse = Design::from_store(store_sparse).expect("valid sparse design"); let (ct_sparse, _) = CrossTab::build_for_pair(&design_sparse, None, 0, 1) .expect("sparse cross tab should build"); @@ -38,7 +38,7 @@ fn test_cross_tab_sparse_accumulation_path() { let fb_small: Vec = fb.iter().map(|&x| x % 100).collect(); let store_dense = FactorMajorStore::new(vec![fa_small.clone(), fb_small.clone()], n_obs) .expect("valid dense store"); - let design_dense = 
WeightedDesign::from_store(store_dense).expect("valid dense design"); + let design_dense = Design::from_store(store_dense).expect("valid dense design"); let (ct_dense, _) = CrossTab::build_for_pair(&design_dense, None, 0, 1).expect("dense cross tab should build"); @@ -117,7 +117,7 @@ fn test_extract_component_two_components() { let fb = vec![0u32, 1, 0, 1, 2, 3, 2, 3]; let n_obs = 8; let store = FactorMajorStore::new(vec![fa, fb], n_obs).expect("valid store"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let (ct, _) = CrossTab::build_for_pair(&design, None, 0, 1).expect("cross tab should build"); let components = ct.bipartite_connected_components(); @@ -227,7 +227,7 @@ proptest! { } let store = FactorMajorStore::new(vec![fa, fb], n_obs).expect("valid store"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let (ct, _) = CrossTab::build_for_pair(&design, None, 0, 1) .expect("cross tab should build"); @@ -283,7 +283,7 @@ fn test_find_all_active_levels_with_gaps() { let fb = vec![0u32, 1, 2, 0, 1, 2]; let n_obs = 6; let store = FactorMajorStore::new(vec![fa, fb], n_obs).expect("valid store"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let active = find_all_active_levels(&design); diff --git a/crates/within/src/operator/preconditioner.rs b/crates/within/src/operator/preconditioner.rs index bf0b9cf..a7c00bc 100644 --- a/crates/within/src/operator/preconditioner.rs +++ b/crates/within/src/operator/preconditioner.rs @@ -17,8 +17,8 @@ use schwarz_precond::{LocalSolver, Operator, ReductionStrategy}; use serde::{Deserialize, Serialize}; use crate::config::Preconditioner; -use crate::domain::WeightedDesign; -use crate::observation::ObservationStore; +use crate::domain::Design; +use 
crate::observation::Store; use crate::operator::schwarz::{build_additive_with_strategy, FeSchwarz}; use crate::WithinResult; @@ -95,8 +95,8 @@ impl Operator for FePreconditioner { /// Build a [`FePreconditioner`] from a design, optional observation weights, /// and configuration. -pub fn build_preconditioner( - design: &WeightedDesign, +pub fn build_preconditioner( + design: &Design, weights: Option<&[f64]>, config: &Preconditioner, ) -> WithinResult { diff --git a/crates/within/src/operator/schwarz.rs b/crates/within/src/operator/schwarz.rs index df1a8b9..067e5d8 100644 --- a/crates/within/src/operator/schwarz.rs +++ b/crates/within/src/operator/schwarz.rs @@ -1,6 +1,6 @@ //! Schwarz preconditioner: FE-specific construction helpers. //! -//! This module bridges the fixed-effects domain types ([`WeightedDesign`], +//! This module bridges the fixed-effects domain types ([`Design`], //! [`Subdomain`], `CrossTab`) to the generic `schwarz-precond` crate API. //! The generic crate knows nothing about panel data — it operates on abstract //! 
[`SubdomainEntry`] values containing a local solver and a set of global DOF diff --git a/crates/within/src/operator/tests.rs b/crates/within/src/operator/tests.rs index 0af205e..7de0844 100644 --- a/crates/within/src/operator/tests.rs +++ b/crates/within/src/operator/tests.rs @@ -3,21 +3,21 @@ // =========================================================================== mod design_tests { - use crate::domain::WeightedDesign; + use crate::domain::Design; use crate::observation::FactorMajorStore; - use crate::operator::WeightedDesignOperator; + use crate::operator::DesignOperator; use schwarz_precond::Operator; - fn make_test_design() -> WeightedDesign { + fn make_test_design() -> Design { let store = FactorMajorStore::new(vec![vec![0, 1, 2, 0, 1], vec![0, 1, 2, 3, 0]], 5) .expect("valid factor-major store"); - WeightedDesign::from_store(store).expect("valid test design") + Design::from_store(store).expect("valid test design") } #[test] fn test_design_operator_dimensions() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema, None); + let op = DesignOperator::new(&schema, None); assert_eq!(op.nrows(), 5); assert_eq!(op.ncols(), 7); } @@ -25,7 +25,7 @@ mod design_tests { #[test] fn test_design_operator_adjoint() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema, None); + let op = DesignOperator::new(&schema, None); let x = vec![1.0, -0.5, 2.0, 0.3, -1.0, 0.7, 1.5]; let r = vec![0.1, 0.2, -0.3, 0.4, -0.5]; @@ -44,7 +44,7 @@ mod design_tests { #[test] fn test_matvec_d() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema, None); + let op = DesignOperator::new(&schema, None); let x = vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0]; let mut y = vec![0.0; 5]; op.apply(&x, &mut y).expect("apply succeeds"); @@ -54,7 +54,7 @@ mod design_tests { #[test] fn test_rmatvec_dt() { let schema = make_test_design(); - let op = WeightedDesignOperator::new(&schema, None); + let op = 
DesignOperator::new(&schema, None); let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let mut x = vec![0.0; 7]; op.apply_adjoint(&r, &mut x) @@ -425,7 +425,7 @@ mod schwarz_tests { use crate::config::{ ApproxCholConfig, ApproxSchurConfig, LocalSolverConfig, DEFAULT_DENSE_SCHUR_THRESHOLD, }; - use crate::domain::{build_local_domains, Subdomain, SubdomainCore, WeightedDesign}; + use crate::domain::{build_local_domains, Design, Subdomain, SubdomainCore}; use crate::observation::FactorMajorStore; use crate::operator::csr_block::CsrBlock; use crate::operator::gramian::CrossTab; @@ -437,10 +437,10 @@ mod schwarz_tests { const BLOCK_ELIM_NESTED_RAYON_CHILD_ENV: &str = "WITHIN_TEST_BLOCK_ELIM_NESTED_RAYON_CHILD"; - fn make_test_data() -> (WeightedDesign, Vec<(Subdomain, CrossTab)>) { + fn make_test_data() -> (Design, Vec<(Subdomain, CrossTab)>) { let store = FactorMajorStore::new(vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]], 5) .expect("valid factor-major store"); - let design = WeightedDesign::from_store(store).expect("valid fixed-effects design"); + let design = Design::from_store(store).expect("valid fixed-effects design"); let domain_pairs = build_local_domains(&design, None); (design, domain_pairs) } diff --git a/crates/within/src/orchestrate.rs b/crates/within/src/orchestrate.rs index 61b30b7..9abd0f9 100644 --- a/crates/within/src/orchestrate.rs +++ b/crates/within/src/orchestrate.rs @@ -6,7 +6,7 @@ //! ```text //! solve(categories, y, weights, params, preconditioner) //! 1. Validate → observation layer builds an ArrayStore, checks dimensions -//! 2. Design → domain layer wraps the store in a WeightedDesign +//! 2. Design → domain layer wraps the store in a Design //! 3. Precond → operator layer builds subdomains + local solvers (Schwarz) //! 4. Solve → Modified LSMR with the Schwarz preconditioner //! 5. 
Extract → return coefficients x and demeaned residuals y - Dx diff --git a/crates/within/src/solver.rs b/crates/within/src/solver.rs index ff86bb8..e00c91f 100644 --- a/crates/within/src/solver.rs +++ b/crates/within/src/solver.rs @@ -38,11 +38,11 @@ use ndarray::ArrayView2; use rayon::prelude::*; use schwarz_precond::{lsmr, mlsmr}; -use crate::operator::WeightedDesignOperator; +use crate::operator::DesignOperator; use crate::config::{Preconditioner, SolverParams}; -use crate::domain::WeightedDesign; -use crate::observation::{validate_weights, ArrayStore, ObservationStore}; +use crate::domain::Design; +use crate::observation::{validate_weights, ArrayStore, Store}; use crate::operator::preconditioner::{build_preconditioner, FePreconditioner}; use crate::orchestrate::{BatchSolveResult, SolveResult}; use crate::WithinResult; @@ -57,8 +57,8 @@ fn norm(v: &[f64]) -> f64 { /// [`Solver::solve`] or [`Solver::solve_batch`] repeatedly with different RHS /// vectors. The expensive preconditioner factorization happens only at /// construction time. -pub struct Solver { - design: WeightedDesign, +pub struct Solver { + design: Design, weights: Option>, preconditioner: Option, tol: f64, @@ -66,13 +66,13 @@ pub struct Solver { local_size: Option, } -impl Solver { - /// Build from an existing [`WeightedDesign`] and optional observation weights. +impl Solver { + /// Build from an existing [`Design`] and optional observation weights. /// /// `weights = None` denotes unit weights (length must match `design.n_rows` /// when `Some`). pub fn from_design( - design: WeightedDesign, + design: Design, weights: Option>, params: &SolverParams, preconditioner: Option<&Preconditioner>, @@ -95,7 +95,7 @@ impl Solver { /// Build from a design with a pre-built preconditioner (e.g. deserialized). 
pub fn from_design_with_preconditioner( - design: WeightedDesign, + design: Design, weights: Option>, params: &SolverParams, preconditioner: FePreconditioner, @@ -116,7 +116,7 @@ impl Solver { let t_start = Instant::now(); let t_setup_start = Instant::now(); - let rect_op = WeightedDesignOperator::new(&self.design, self.weights.as_deref()); + let rect_op = DesignOperator::new(&self.design, self.weights.as_deref()); let b = rect_op.weighted_rhs(y); let t_solve_start = Instant::now(); @@ -217,7 +217,7 @@ impl<'a> Solver> { preconditioner: Option<&Preconditioner>, ) -> WithinResult { let store = ArrayStore::new(categories)?; - let design = WeightedDesign::from_store(store)?; + let design = Design::from_store(store)?; let weights = weights.map(|w| w.to_vec()); Self::from_design(design, weights, params, preconditioner) } @@ -230,7 +230,7 @@ impl<'a> Solver> { preconditioner: FePreconditioner, ) -> WithinResult { let store = ArrayStore::new(categories)?; - let design = WeightedDesign::from_store(store)?; + let design = Design::from_store(store)?; let weights = weights.map(|w| w.to_vec()); Self::from_design_with_preconditioner(design, weights, params, preconditioner) } diff --git a/crates/within/tests/array_store.rs b/crates/within/tests/array_store.rs index 53bf3ab..b60c282 100644 --- a/crates/within/tests/array_store.rs +++ b/crates/within/tests/array_store.rs @@ -1,6 +1,6 @@ use ndarray::{array, Array2, ShapeBuilder}; -use within::observation::{ArrayStore, FactorMajorStore, ObservationStore}; -use within::{solve, Preconditioner, SolverParams, WeightedDesign}; +use within::observation::{ArrayStore, FactorMajorStore, Store}; +use within::{solve, Design, Preconditioner, SolverParams}; #[path = "common/orchestrate_helpers.rs"] mod common; @@ -55,7 +55,7 @@ fn test_array_store_f_contiguous_matches_factor_major() { .map(|f| cats.column(f).iter().copied().collect()) .collect(); let store = FactorMajorStore::new(factor_cols, cats.nrows()).expect("valid FactorMajorStore"); - 
let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let solver = within::Solver::from_design(design, None, &default_params(), Some(&additive_precond())) .expect("solver"); diff --git a/crates/within/tests/common/orchestrate_helpers.rs b/crates/within/tests/common/orchestrate_helpers.rs index 0758e61..ada62e0 100644 --- a/crates/within/tests/common/orchestrate_helpers.rs +++ b/crates/within/tests/common/orchestrate_helpers.rs @@ -1,21 +1,19 @@ #![allow(dead_code)] -use within::{FactorMajorStore, SolveResult, WeightedDesign}; +use within::{Design, FactorMajorStore, SolveResult}; -pub fn make_test_design() -> WeightedDesign { +pub fn make_test_design() -> Design { make_design(vec![vec![0, 1, 0, 1, 2], vec![0, 0, 1, 1, 0]]).expect("valid test design") } -pub fn make_design( - categories: Vec>, -) -> within::WithinResult> { +pub fn make_design(categories: Vec>) -> within::WithinResult> { let n_rows = categories.first().map_or(0, Vec::len); let store = FactorMajorStore::new(categories, n_rows)?; - WeightedDesign::from_store(store) + Design::from_store(store) } /// Compute y = D * 1 so that the true solution of `min ||y - Dx||^2` is x = 1. -pub fn make_y_from_unit_solution(design: &WeightedDesign) -> Vec { +pub fn make_y_from_unit_solution(design: &Design) -> Vec { let x_true = vec![1.0; design.n_dofs]; let mut y = vec![0.0; design.n_rows]; design.matvec_d(&x_true, &mut y); diff --git a/crates/within/tests/domain.rs b/crates/within/tests/domain.rs index 912d7a0..0e970e1 100644 --- a/crates/within/tests/domain.rs +++ b/crates/within/tests/domain.rs @@ -1,18 +1,18 @@ -//! Integration tests for the domain layer: WeightedDesign operations, +//! Integration tests for the domain layer: Design operations, //! adjoint properties, and convergence through the solve API for designs that //! exercise partition-of-unity weights and disconnected bipartite structure. 
use proptest::prelude::*; use within::observation::FactorMajorStore; -use within::WeightedDesign; +use within::Design; // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- -fn make_design(categories: Vec>, n_obs: usize) -> WeightedDesign { +fn make_design(categories: Vec>, n_obs: usize) -> Design { let store = FactorMajorStore::new(categories, n_obs).expect("valid factor-major store"); - WeightedDesign::from_store(store).expect("valid design") + Design::from_store(store).expect("valid design") } fn dot(a: &[f64], b: &[f64]) -> f64 { @@ -26,7 +26,7 @@ fn dot(a: &[f64], b: &[f64]) -> f64 { /// Build a 15,000-row design with two factors (~50 levels each). /// This exercises the parallel code paths in `gather_add` (par_chunks_mut) /// and `scatter_add` (Fold strategy: n_rows > 10,000, n_levels < 100,000). -fn make_large_design() -> WeightedDesign { +fn make_large_design() -> Design { let n_obs = 15_000; let n_levels_a = 50usize; let n_levels_b = 50usize; @@ -132,7 +132,7 @@ proptest! 
{ .collect(); let store = FactorMajorStore::new(vec![fa, fb], n_obs).unwrap(); - let dm = WeightedDesign::from_store(store).unwrap(); + let dm = Design::from_store(store).unwrap(); let n_dofs = dm.n_dofs; let n_rows = dm.n_rows; @@ -188,7 +188,7 @@ fn test_three_factor_design_solve_converges() { let fc: Vec = (0..n_obs).map(|i| ((i * 3) % n_lev) as u32).collect(); let store = FactorMajorStore::new(vec![fa, fb, fc], n_obs).expect("valid 3-factor store"); - let dm = WeightedDesign::from_store(store).expect("valid 3-factor design"); + let dm = Design::from_store(store).expect("valid 3-factor design"); assert_eq!(dm.n_factors(), 3); @@ -252,7 +252,7 @@ fn test_disconnected_design_larger_converges() { let store = FactorMajorStore::new(vec![fa.clone(), fb.clone()], n_obs) .expect("valid disconnected store"); - let dm = WeightedDesign::from_store(store).expect("valid disconnected design"); + let dm = Design::from_store(store).expect("valid disconnected design"); let x_true = vec![1.0f64; dm.n_dofs]; let mut y = vec![0.0f64; dm.n_rows]; @@ -315,7 +315,7 @@ fn test_disconnected_design_solve_converges() { fn test_single_factor_design_construction() { let categories = vec![vec![0u32, 1, 2, 0, 1]]; let store = FactorMajorStore::new(categories, 5).expect("valid store"); - let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); + let dm = Design::from_store(store).expect("valid single-factor design"); assert_eq!(dm.n_factors(), 1, "expected 1 factor"); assert_eq!(dm.n_dofs, 3, "expected 3 DOFs (levels 0,1,2)"); @@ -326,7 +326,7 @@ fn test_single_factor_design_construction() { fn test_single_factor_design_adjoint_property() { let categories = vec![vec![0u32, 1, 2, 0, 1]]; let store = FactorMajorStore::new(categories, 5).expect("valid store"); - let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); + let dm = Design::from_store(store).expect("valid single-factor design"); let n_dofs = dm.n_dofs; let n_rows = dm.n_rows; @@ -388,7 
+388,7 @@ fn test_single_factor_matvec_d_values() { // D·[a, b, c] with levels [0,1,2,0,1] should give [a, b, c, a, b] let categories = vec![vec![0u32, 1, 2, 0, 1]]; let store = FactorMajorStore::new(categories, 5).expect("valid store"); - let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); + let dm = Design::from_store(store).expect("valid single-factor design"); let x = vec![10.0, 20.0, 30.0]; let mut y = vec![0.0f64; 5]; @@ -401,7 +401,7 @@ fn test_single_factor_rmatvec_dt_values() { // D^T·[1,2,3,4,5] with levels [0,1,2,0,1] should give [1+4, 2+5, 3] = [5, 7, 3] let categories = vec![vec![0u32, 1, 2, 0, 1]]; let store = FactorMajorStore::new(categories, 5).expect("valid store"); - let dm = WeightedDesign::from_store(store).expect("valid single-factor design"); + let dm = Design::from_store(store).expect("valid single-factor design"); let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let mut x = vec![0.0f64; 3]; diff --git a/crates/within/tests/edge_cases.rs b/crates/within/tests/edge_cases.rs index 9a4fe68..a0cbd0e 100644 --- a/crates/within/tests/edge_cases.rs +++ b/crates/within/tests/edge_cases.rs @@ -2,7 +2,7 @@ use ndarray::array; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; use within::observation::FactorMajorStore; -use within::{solve, Preconditioner, Solver, SolverParams, WeightedDesign}; +use within::{solve, Design, Preconditioner, Solver, SolverParams}; #[path = "common/orchestrate_helpers.rs"] mod common; @@ -138,7 +138,7 @@ fn test_maxiter_1_partial_result() { (0..n_obs).map(|_| rng.random_range(0..20u32)).collect(), ]; let store = FactorMajorStore::new(cats, n_obs).expect("valid store"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let y: Vec = (0..n_obs).map(|i| (i as f64 * 0.17).sin()).collect(); @@ -185,7 +185,7 @@ fn test_large_design_convergence() { ]; let store = FactorMajorStore::new(cats, n_obs).expect("valid large 
store"); - let design = WeightedDesign::from_store(store).expect("valid large design"); + let design = Design::from_store(store).expect("valid large design"); let y = common::make_y_from_unit_solution(&design); let params = SolverParams { diff --git a/crates/within/tests/error_paths.rs b/crates/within/tests/error_paths.rs index d52801e..6bb3742 100644 --- a/crates/within/tests/error_paths.rs +++ b/crates/within/tests/error_paths.rs @@ -3,13 +3,13 @@ use std::error::Error; use ndarray::Array2; use schwarz_precond::{PreconditionerBuildError, SolveError}; use within::observation::FactorMajorStore; -use within::{solve, Preconditioner, Solver, SolverParams, WeightedDesign, WithinError}; +use within::{solve, Design, Preconditioner, Solver, SolverParams, WithinError}; #[test] fn test_empty_observations_error() { - // FactorMajorStore::new allows 0 rows; EmptyObservations is raised by WeightedDesign::from_store + // FactorMajorStore::new allows 0 rows; EmptyObservations is raised by Design::from_store let store = FactorMajorStore::new(vec![vec![], vec![]], 0).expect("store ok"); - let result = WeightedDesign::from_store(store); + let result = Design::from_store(store); assert!(result.is_err()); match result.unwrap_err() { WithinError::EmptyObservations => {} @@ -32,7 +32,7 @@ fn test_observation_count_mismatch_error() { fn test_weight_count_mismatch_error() { // Weights of wrong length are caught at Solver construction time. 
let store = FactorMajorStore::new(vec![vec![0, 1, 2], vec![0, 1, 0]], 3).expect("store ok"); - let design = WeightedDesign::from_store(store).expect("valid design"); + let design = Design::from_store(store).expect("valid design"); let params = SolverParams::default(); let result = Solver::from_design(design, Some(vec![1.0, 2.0]), &params, None); let err = result diff --git a/crates/within/tests/properties.rs b/crates/within/tests/properties.rs index 26f0c83..60e079c 100644 --- a/crates/within/tests/properties.rs +++ b/crates/within/tests/properties.rs @@ -2,7 +2,7 @@ use ndarray::Array2; use proptest::prelude::*; use schwarz_precond::Operator; use within::observation::ArrayStore; -use within::{solve, FePreconditioner, Preconditioner, SolverParams, WeightedDesign}; +use within::{solve, Design, FePreconditioner, Preconditioner, SolverParams}; /// Generate a random fixed-effects problem as (categories Array2, y Vec). fn random_fe_problem_strategy() -> impl Strategy, Vec)> { @@ -73,7 +73,7 @@ proptest! { fn prop_solver_convergence((cats, _y) in random_fe_problem_strategy()) { // Create y = D * x_true so we know the answer let store = ArrayStore::new(cats.view()).unwrap(); - let design = WeightedDesign::from_store(store).unwrap(); + let design = Design::from_store(store).unwrap(); let n_dofs = design.n_dofs; let n_obs = design.n_rows; diff --git a/crates/within/tests/property_gaps.rs b/crates/within/tests/property_gaps.rs index 3087da4..2a089b7 100644 --- a/crates/within/tests/property_gaps.rs +++ b/crates/within/tests/property_gaps.rs @@ -1,7 +1,7 @@ use ndarray::Array2; use proptest::prelude::*; use within::observation::ArrayStore; -use within::{solve, Preconditioner, Solver, SolverParams, WeightedDesign}; +use within::{solve, Design, Preconditioner, Solver, SolverParams}; #[path = "common/orchestrate_helpers.rs"] mod common; @@ -105,7 +105,7 @@ proptest! 
{ let precond = additive_precond(); // Build the design with a unit-solution RHS so the problem is feasible let store = ArrayStore::new(cats.view()).unwrap(); - let design = WeightedDesign::from_store(store).unwrap(); + let design = Design::from_store(store).unwrap(); let y_feasible: Vec = { let x_true = vec![1.0; design.n_dofs]; let mut y_out = vec![0.0; design.n_rows]; @@ -177,7 +177,7 @@ proptest! { let n_obs = y.len(); let n_factors = cats.ncols(); - // Compute factor offsets (same ordering as WeightedDesign) + // Compute factor offsets (same ordering as Design) let mut offsets = vec![0usize; n_factors]; for f in 1..n_factors { let n_levels_prev = *cats.column(f - 1).iter().max().unwrap() as usize + 1; @@ -208,7 +208,7 @@ proptest! { fn prop_single_factor_converges((cats, _y) in single_factor_strategy()) { // Build a consistent RHS: y = D * 1 so the system is exactly solvable. let store = ArrayStore::new(cats.view()).unwrap(); - let design = WeightedDesign::from_store(store).unwrap(); + let design = Design::from_store(store).unwrap(); let n_levels = design.n_dofs; let x_true = vec![1.0; n_levels]; let mut y_feasible = vec![0.0; design.n_rows]; From 132d22dd3332e9c25c791b9305e371a9b657da4b Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 11 May 2026 11:17:40 +0200 Subject: [PATCH 3/6] strip operator methods from Design; move helpers into operator.rs Design is now pure data + layout: store, factors, n_rows, n_dofs, from_store, n_factors. The matvec helpers (gather_add, scatter_add, scatter_add_single_factor, level_from_column_or_store, ScatterStrategy, PAR_THRESHOLD, SCATTER_LOCAL_THRESHOLD, factor_columns) move into operator.rs as module-private free functions taking &Design. The DesignOperator::apply / apply_adjoint methods now call them directly. - Removed Design::matvec_d, rmatvec_dt, rmatvec_wdt, uid_weight (and the internal gather_add/scatter_add methods). The uid_weight read in cross_tab.rs inlines to weights.map_or(1.0, |w| w[uid]). 
- Solver::solve diagnostic computes D^T W v via apply_adjoint(W^{1/2} v) instead of the dedicated rmatvec_wdt method; results are bit-identical because D^T W = D^T W^{1/2} W^{1/2} algebraically. - Tests/examples/benches that called design.matvec_d(...) etc. now use DesignOperator::new(&design, None).apply(...) (and likewise for the adjoint). The internal tests in src/domain.rs and src/operator/tests.rs are pruned to remove redundant matvec coverage that DesignOperator already exercises. --- CHANGELOG.md | 22 +- crates/within/benches/fixest.rs | 4 +- crates/within/examples/solve_demo.rs | 6 +- crates/within/src/domain.rs | 387 +----------------- crates/within/src/operator.rs | 237 ++++++++++- .../within/src/operator/gramian/cross_tab.rs | 6 +- crates/within/src/solver.rs | 16 +- .../tests/common/orchestrate_helpers.rs | 6 +- crates/within/tests/domain.rs | 45 +- crates/within/tests/properties.rs | 5 +- crates/within/tests/property_gaps.rs | 10 +- 11 files changed, 334 insertions(+), 410 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce394af..45f38c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,18 @@ Modified LSMR is now the sole iterative solver, replacing CG and GMRES. type parameter; the `LocalSolveInvoker` trait, `DefaultLocalSolveInvoker`, and `with_strategy_and_invoker` constructor are removed. - LSMR vector kernels parallelized via Rayon. +- **BREAKING:** `ObservationStore` trait renamed to `Store`; + `WeightedDesign` to `Design`; `WeightedDesignOperator` to + `DesignOperator` (closes #28). +- **BREAKING:** Observation weights externalized from the store layer. + `FactorMajorStore::new` and `ArrayStore::new` drop their weights + argument. `Solver::from_design` / `from_design_with_preconditioner` and + `build_preconditioner` gain `Option<Vec<f64>>` / `Option<&[f64]>` + weights parameters. Python `solve` / `Solver` signatures unchanged. +- **BREAKING:** `Design` is now pure data + layout. 
The `matvec_d`, + `rmatvec_dt`, `rmatvec_wdt`, `gramian_diagonal`, and `uid_weight` + methods are removed — use `DesignOperator::new(&design, weights).apply` + / `apply_adjoint` instead. ### Removed @@ -41,9 +53,11 @@ Modified LSMR is now the sole iterative solver, replacing CG and GMRES. `SolverParams.max_refinements`, Python `CG`/`GMRES`/ `MultiplicativeSchwarz`, `MultiplicativeSchwarzPreconditioner`, `ResidualUpdater`, `OperatorResidualUpdater`, `IdentityOperator`). -- **BREAKING:** `Gramian`, `GramianOperator`, `DesignOperator`, - `build_schwarz`, `FeSchwarz`, and `WithinError::Overflow` removed from - the `within` public surface. LSMR uses `WeightedDesignOperator` directly. +- **BREAKING:** `Gramian`, `GramianOperator`, the previous bare + `DesignOperator`, `build_schwarz`, `FeSchwarz`, and + `WithinError::Overflow` removed from the `within` public surface. LSMR + uses the rectangular weighted design operator directly (see also #28, + which later renamed it to `DesignOperator`). - **BREAKING:** `schwarz_precond::solve::{cg, gmres}` and the `solve` module removed; use crate-root `lsmr`/`mlsmr`. `schwarz_precond::schwarz::{additive, multiplicative}` flattened into @@ -60,6 +74,8 @@ Modified LSMR is now the sole iterative solver, replacing CG and GMRES. variants moved onto `SolveError`. `PyFePreconditioner.apply` raises `RuntimeError` on local-solver failure instead of returning NaNs (closes #29). +- **BREAKING:** `ObservationWeights` enum removed; `Store::weight` and + `Store::is_unweighted` removed from the trait (closes #28). 
## [0.1.0] - 2026-03-12 diff --git a/crates/within/benches/fixest.rs b/crates/within/benches/fixest.rs index 30403f7..2841b82 100644 --- a/crates/within/benches/fixest.rs +++ b/crates/within/benches/fixest.rs @@ -82,7 +82,9 @@ fn generate_fixest_like_case(case: Case, seed: u64) -> (Design } let mut y = vec![0.0; case.n_obs]; - design.matvec_d(&x_true, &mut y); + DesignOperator::new(&design, None) + .apply(&x_true, &mut y) + .expect("apply succeeds"); for yi in &mut y { *yi += 0.1 * rng.random_range(-1.0..1.0); } diff --git a/crates/within/examples/solve_demo.rs b/crates/within/examples/solve_demo.rs index 75f82e5..f6a3d14 100644 --- a/crates/within/examples/solve_demo.rs +++ b/crates/within/examples/solve_demo.rs @@ -19,7 +19,9 @@ fn main() { } // Build design to compute D * x_true. + use schwarz_precond::Operator as _; use within::observation::FactorMajorStore; + use within::operator::DesignOperator; use within::Design; let factor_levels = vec![categories.column(0).to_vec(), categories.column(1).to_vec()]; @@ -30,7 +32,9 @@ fn main() { let total_dofs = design.n_dofs; let x_true: Vec = (0..total_dofs).map(|j| (j % 7) as f64 - 3.0).collect(); let mut y = vec![0.0; n_obs]; - design.matvec_d(&x_true, &mut y); + DesignOperator::new(&design, None) + .apply(&x_true, &mut y) + .expect("apply succeeds"); // Add small deterministic perturbation so the system is not trivially exact. for (i, yi) in y.iter_mut().enumerate() { *yi += 0.01 * ((i * 7 + 3) % 13) as f64 - 0.06; diff --git a/crates/within/src/domain.rs b/crates/within/src/domain.rs index 3c64a25..c7d1c5d 100644 --- a/crates/within/src/domain.rs +++ b/crates/within/src/domain.rs @@ -1,14 +1,13 @@ -//! Domain layer: weighted design matrix and factor-pair subdomain construction. +//! Domain layer: design matrix metadata and factor-pair subdomain construction. //! //! This module sits between raw observation storage ([`crate::observation`]) and //! the linear-algebra operators ([`crate::operator`]). 
It answers two questions: //! -//! 1. **What does the design matrix look like?** — [`Design`] wraps an -//! [`Store`] with per-factor metadata ([`FactorMeta`]) and -//! provides the core matrix-vector products (`D·x`, `D^T·r`, `D^T·W·r`) -//! needed by every solver path. Observation weights are passed in as -//! `Option<&[f64]>` rather than owned, keeping the design itself purely -//! structural. +//! 1. **What does the design matrix look like?** — [`Design`] wraps a [`Store`] +//! with per-factor metadata ([`FactorMeta`]). It is *pure data + layout*: it +//! knows the number of rows, the number of DOFs, and how to recover factor +//! levels per observation. The matrix-vector products live next door in +//! [`crate::operator::DesignOperator`]. //! //! 2. **How is the problem decomposed into subdomains?** — The `factor_pairs` //! submodule builds one [`Subdomain`] per connected component of each factor @@ -45,7 +44,6 @@ //! gets a local solver, and the partition-of-unity weights ensure that //! overlapping DOFs (levels that appear in multiple factor pairs) are correctly //! scaled. See the `factor_pairs` submodule for details. -//! pub(crate) mod factor_pairs; @@ -74,27 +72,17 @@ impl std::fmt::Debug for Subdomain { } // =========================================================================== -// Weighted design matrix +// Design — categorical fixed-effects design (data + layout) // =========================================================================== -// `Design` is generic over `Store`. It stores per-factor -// metadata via `FactorMeta` and delegates observation data access to the -// pluggable store backend `S`. Observation weights are not owned here — they -// are passed in as `Option<&[f64]>` to the weight-sensitive matvec methods. 
-use std::sync::atomic::Ordering; - -use portable_atomic::AtomicF64; -use rayon::prelude::*; - use crate::observation::{FactorMeta, Store}; use crate::{WithinError, WithinResult}; -/// Weighted fixed-effects design matrix, generic over observation storage. +/// Fixed-effects design, generic over observation storage. /// /// `store` holds per-observation factor levels; `factors` holds per-factor -/// metadata (n_levels, offset). Observation weights are **not** stored — they -/// are passed in to methods that need them via `Option<&[f64]>` (where `None` -/// denotes unit weights). +/// metadata (n_levels, offset). The `Design` itself is pure data + layout — +/// matrix-vector products live in [`crate::operator::DesignOperator`]. pub struct Design { /// Observation storage backend (owns or borrows the raw factor levels). pub store: S, @@ -155,364 +143,9 @@ impl Design { }) } - /// Weight for observation `uid` given an optional weight slice. - /// - /// `None` ⇒ 1.0 (unit weights); `Some(w)` ⇒ `w[uid]`. - #[inline] - pub fn uid_weight(weights: Option<&[f64]>, uid: usize) -> f64 { - match weights { - None => 1.0, - Some(w) => w[uid], - } - } - /// Number of categorical factors in the design. #[inline] pub fn n_factors(&self) -> usize { self.factors.len() } - - /// Pre-compute factor column slices for all factors. - /// - /// Returns a vec where entry `q` is the store's contiguous column for factor `q`, - /// or `None` if the store doesn't support direct column access. - fn factor_columns(&self) -> Vec> { - self.factors - .iter() - .enumerate() - .map(|(q, _)| self.store.factor_column(q)) - .collect() - } -} - -// --------------------------------------------------------------------------- -// Design matrix operations (D·x, D^T·r, D^T·W·r) -// --------------------------------------------------------------------------- - -/// Minimum number of rows before scatter/gather loops are parallelized. 
-const PAR_THRESHOLD: usize = 10_000; - -/// Factor-level threshold for choosing between fold and atomic scatter-add. -/// -/// Factors with fewer than this many levels use thread-local fold/reduce -/// (O(n_levels * n_threads) memory). Larger factors use atomic CAS instead, -/// which has low contention when bins vastly outnumber threads. -/// 100K levels * 8 bytes * ~24 Rayon tasks ~ 19 MB — fits comfortably. -const SCATTER_LOCAL_THRESHOLD: usize = 100_000; - -/// Strategy for a single factor's scatter-add loop. -enum ScatterStrategy { - /// Plain sequential loop — used when n_rows is below `PAR_THRESHOLD`. - Sequential, - /// Parallel fold/reduce with thread-local accumulators — for small factors. - Fold, - /// Parallel atomic CAS — for large factors with low contention. - Atomic, -} - -#[inline] -fn level_from_column_or_store( - store: &S, - levels: Option<&[u32]>, - row: usize, - factor: usize, -) -> usize { - match levels { - Some(col) => col[row] as usize, - None => store.level(row, factor) as usize, - } -} - -/// Scatter-add for a single factor, dispatched by strategy. -/// -/// Accumulates `value_fn(i)` into `slice[level(i, q)]` for all rows, using the -/// requested parallelization strategy. The `atomic_buf` is reused across calls -/// to avoid repeated allocation in the `Atomic` path. -/// -/// All branches share the same per-row `(level, value)` computation via -/// `level_value`; only the accumulation strategy differs. 
-#[allow(clippy::too_many_arguments)] -fn scatter_add_single_factor( - slice: &mut [f64], - n_rows: usize, - n_levels: usize, - store: &S, - levels: Option<&[u32]>, - q: usize, - value_fn: &(impl Fn(usize) -> f64 + Sync), - strategy: ScatterStrategy, - atomic_buf: &mut Vec, -) { - #[inline(always)] - fn level_value( - store: &S, - levels: Option<&[u32]>, - q: usize, - value_fn: &impl Fn(usize) -> f64, - i: usize, - ) -> (usize, f64) { - let level = level_from_column_or_store(store, levels, i, q); - (level, value_fn(i)) - } - - match strategy { - ScatterStrategy::Sequential => { - for i in 0..n_rows { - let (level, val) = level_value(store, levels, q, value_fn, i); - slice[level] += val; - } - } - ScatterStrategy::Fold => { - let min_len = (n_rows / rayon::current_num_threads().max(1)).max(1024); - let result: Vec = (0..n_rows) - .into_par_iter() - .with_min_len(min_len) - .fold( - || vec![0.0f64; n_levels], - |mut acc, i| { - let (level, val) = level_value(store, levels, q, value_fn, i); - acc[level] += val; - acc - }, - ) - .reduce( - || vec![0.0f64; n_levels], - |mut a, b| { - for (ai, bi) in a.iter_mut().zip(b.iter()) { - *ai += *bi; - } - a - }, - ); - for (d, r) in slice.iter_mut().zip(result.iter()) { - *d += *r; - } - } - ScatterStrategy::Atomic => { - atomic_buf.clear(); - atomic_buf.extend(slice.iter().map(|&v| AtomicF64::new(v))); - (0..n_rows).into_par_iter().for_each(|i| { - let (level, val) = level_value(store, levels, q, value_fn, i); - atomic_buf[level].fetch_add(val, Ordering::Relaxed); - }); - for (d, a) in slice.iter_mut().zip(atomic_buf.iter()) { - *d = a.load(Ordering::Relaxed); - } - } - } -} - -impl Design { - /// Gather-add: `dst[i] += src[offset_q + level(i, q)]` for each factor `q` and row `i`. - /// - /// This is the core loop of `y = D·x`. Loop order is chosen based on the - /// store's preferred iteration pattern for cache locality. 
- /// - /// For large problems (n_rows > 10 000), rows are partitioned into chunks - /// and processed in parallel via Rayon `par_chunks_mut`. - #[inline] - fn gather_add(&self, src: &[f64], dst: &mut [f64]) { - const CHUNK_SIZE: usize = 4096; - let factor_columns = self.factor_columns(); - - if self.n_rows > PAR_THRESHOLD { - // Parallel path: each chunk processes its own row range. - // The inner loop iterates factors inside each chunk, which is optimal for the - // common case (2-3 factors) where all factor data fits in L1 cache. For many - // factors (10+) a layout with factors in the outer loop might help, but - // econometric models typically have 2-5 factors so this isn't worth optimizing. - dst.par_chunks_mut(CHUNK_SIZE) - .enumerate() - .for_each(|(chunk_idx, chunk)| { - let row_start = chunk_idx * CHUNK_SIZE; - for (q, f) in self.factors.iter().enumerate() { - let levels = factor_columns[q]; - for (local, dst_val) in chunk.iter_mut().enumerate() { - let i = row_start + local; - let level = level_from_column_or_store(&self.store, levels, i, q); - *dst_val += src[f.offset + level]; - } - } - }); - } else { - // Sequential factor-major: outer loop on factors, inner on observations. - for (q, f) in self.factors.iter().enumerate() { - let levels = factor_columns[q]; - for (i, dst_i) in dst.iter_mut().enumerate().take(self.n_rows) { - let level = level_from_column_or_store(&self.store, levels, i, q); - *dst_i += src[f.offset + level]; - } - } - } - } - - /// Scatter-add: `dst[offset_q + level(i, q)] += value_fn(i)` for each factor `q` and row `i`. - /// - /// This is the core loop of `x = D^T · r` (and weighted variant `D^T · W · r`). 
- /// The `value_fn` closure computes the per-row contribution: - /// - unweighted: `|i| r[i]` - /// - weighted: `|i| w[i] * r[i]` - /// - /// For large problems, each factor's row loop is parallelized: - /// - Small factors (< 100K levels): thread-local fold/reduce (avoids CAS contention) - /// - Large factors: atomic CAS scatter (low contention on millions of bins) - /// Factors are processed sequentially so each gets the full thread pool. - #[inline] - fn scatter_add(&self, dst: &mut [f64], value_fn: impl Fn(usize) -> f64 + Sync) { - let factor_columns = self.factor_columns(); - let parallel = self.n_rows > PAR_THRESHOLD; - let max_levels = self.factors.iter().map(|f| f.n_levels).max().unwrap_or(0); - let mut atomic_buf: Vec = Vec::with_capacity(max_levels); - - for (q, f) in self.factors.iter().enumerate() { - let slice = &mut dst[f.offset..f.offset + f.n_levels]; - let levels = factor_columns[q]; - let strategy = if !parallel { - ScatterStrategy::Sequential - } else if f.n_levels < SCATTER_LOCAL_THRESHOLD { - ScatterStrategy::Fold - } else { - ScatterStrategy::Atomic - }; - scatter_add_single_factor( - slice, - self.n_rows, - f.n_levels, - &self.store, - levels, - q, - &value_fn, - strategy, - &mut atomic_buf, - ); - } - } - - /// y = D·x (gather-add, no weights) - pub fn matvec_d(&self, x: &[f64], y: &mut [f64]) { - debug_assert_eq!(x.len(), self.n_dofs); - debug_assert_eq!(y.len(), self.n_rows); - y.fill(0.0); - self.gather_add(x, y); - } - - /// x = D^T·r (scatter-add, no weights) - pub fn rmatvec_dt(&self, r: &[f64], x: &mut [f64]) { - debug_assert_eq!(r.len(), self.n_rows); - debug_assert_eq!(x.len(), self.n_dofs); - x.fill(0.0); - self.scatter_add(x, |i| r[i]); - } - - /// x = D^T·W·r (weighted scatter-add) - /// - /// `weights = None` falls through to [`Self::rmatvec_dt`]. Otherwise the - /// per-row factor `w[i] * r[i]` is applied inside the scatter loop. The - /// branch lives outside the inner loop. 
- pub fn rmatvec_wdt(&self, weights: Option<&[f64]>, r: &[f64], x: &mut [f64]) { - debug_assert_eq!(r.len(), self.n_rows); - debug_assert_eq!(x.len(), self.n_dofs); - let Some(w) = weights else { - return self.rmatvec_dt(r, x); - }; - debug_assert_eq!(w.len(), self.n_rows); - x.fill(0.0); - self.scatter_add(x, |i| w[i] * r[i]); - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use crate::observation::FactorMajorStore; - - fn make_test_design() -> Design { - let categories = vec![vec![0, 1, 2, 0, 1], vec![0, 1, 2, 3, 0]]; - let store = FactorMajorStore::new(categories, 5).expect("valid factor-major store"); - Design::from_store(store).expect("valid test design") - } - - fn make_test_design_2x2() -> Design { - let categories = vec![vec![0, 1, 0, 1], vec![0, 0, 1, 1]]; - let store = - FactorMajorStore::new(categories, 4).expect("valid weighted factor-major store"); - Design::from_store(store).expect("valid weighted design") - } - - #[test] - fn test_construction() { - let dm = make_test_design(); - assert_eq!(dm.n_factors(), 2); - assert_eq!(dm.n_dofs, 7); - assert_eq!(dm.n_rows, 5); - assert_eq!(dm.factors[0].offset, 0); - assert_eq!(dm.factors[1].offset, 3); - let block_offsets: Vec = dm - .factors - .iter() - .map(|f| f.offset) - .chain(std::iter::once(dm.n_dofs)) - .collect(); - assert_eq!(block_offsets, vec![0, 3, 7]); - } - - #[test] - fn test_factor_meta() { - let dm = make_test_design(); - assert_eq!(dm.factors[0].n_levels, 3); - assert_eq!(dm.factors[1].n_levels, 4); - assert_eq!(dm.store.level(0, 0), 0); - assert_eq!(dm.store.level(1, 0), 1); - assert_eq!(dm.store.level(2, 0), 2); - assert_eq!(dm.store.level(3, 0), 0); - assert_eq!(dm.store.level(4, 0), 1); - assert_eq!(dm.store.level(0, 1), 0); - assert_eq!(dm.store.level(1, 1), 1); - assert_eq!(dm.store.level(4, 1), 0); - } - - 
#[test] - fn test_matvec_d() { - let dm = make_test_design(); - let x = vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0]; - let mut y = vec![0.0; 5]; - dm.matvec_d(&x, &mut y); - assert_eq!(y, vec![11.0, 22.0, 33.0, 41.0, 12.0]); - } - - #[test] - fn test_rmatvec_dt() { - let dm = make_test_design(); - let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; - let mut x = vec![0.0; 7]; - dm.rmatvec_dt(&r, &mut x); - assert_eq!(x, vec![5.0, 7.0, 3.0, 6.0, 2.0, 3.0, 4.0]); - } - - #[test] - fn test_rmatvec_wdt_unweighted() { - let dm = make_test_design(); - let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; - let mut x_dt = vec![0.0; 7]; - let mut x_wdt = vec![0.0; 7]; - dm.rmatvec_dt(&r, &mut x_dt); - dm.rmatvec_wdt(None, &r, &mut x_wdt); - assert_eq!(x_dt, x_wdt); - } - - #[test] - fn test_rmatvec_wdt_weighted() { - let dm = make_test_design_2x2(); - let weights = [1.0, 2.0, 3.0, 4.0]; - let r = vec![1.0, 1.0, 1.0, 1.0]; - let mut x = vec![0.0; 4]; - dm.rmatvec_wdt(Some(&weights), &r, &mut x); - // factor 0: level 0 has obs 0(w=1)+2(w=3)=4, level 1 has obs 1(w=2)+3(w=4)=6 - // factor 1: level 0 has obs 0(w=1)+1(w=2)=3, level 1 has obs 2(w=3)+3(w=4)=7 - assert_eq!(x, vec![4.0, 6.0, 3.0, 7.0]); - } } diff --git a/crates/within/src/operator.rs b/crates/within/src/operator.rs index b5f8707..a43e959 100644 --- a/crates/within/src/operator.rs +++ b/crates/within/src/operator.rs @@ -43,8 +43,11 @@ mod tests; // DesignOperator — rectangular, W^{1/2}·D·x / D^T·W^{1/2}·x // --------------------------------------------------------------------------- +use std::sync::atomic::Ordering; use std::sync::Mutex; +use portable_atomic::AtomicF64; +use rayon::prelude::*; use schwarz_precond::Operator; use crate::domain::Design; @@ -66,7 +69,7 @@ pub struct DesignOperator<'a, S: Store> { } impl<'a, S: Store> DesignOperator<'a, S> { - /// Create from a weighted design matrix and optional observation weights. + /// Create from a design matrix and optional observation weights. 
/// /// `weights = None` selects the unweighted fast-path: `sqrt_weights` is /// `None` and `apply` / `apply_adjoint` skip the per-row scaling entirely. @@ -89,6 +92,7 @@ impl<'a, S: Store> DesignOperator<'a, S> { } } } + impl Operator for DesignOperator<'_, S> { fn nrows(&self) -> usize { self.design.n_rows @@ -99,8 +103,11 @@ impl Operator for DesignOperator<'_, S> { } fn apply(&self, x: &[f64], y: &mut [f64]) -> Result<(), schwarz_precond::SolveError> { + debug_assert_eq!(x.len(), self.design.n_dofs); + debug_assert_eq!(y.len(), self.design.n_rows); // y = W^{1/2} (D x) - self.design.matvec_d(x, y); + y.fill(0.0); + gather_add(self.design, x, y); if let Some(sw) = &self.sqrt_weights { for (yi, &swi) in y.iter_mut().zip(sw) { *yi *= swi; @@ -110,17 +117,239 @@ impl Operator for DesignOperator<'_, S> { } fn apply_adjoint(&self, x: &[f64], y: &mut [f64]) -> Result<(), schwarz_precond::SolveError> { + debug_assert_eq!(x.len(), self.design.n_rows); + debug_assert_eq!(y.len(), self.design.n_dofs); // y = D^T (W^{1/2} x) + y.fill(0.0); match &self.sqrt_weights { - None => self.design.rmatvec_dt(x, y), + None => scatter_add(self.design, y, |i| x[i]), Some(sw) => { let mut tmp = self.scratch.lock().unwrap(); for (ti, (&xi, &swi)) in tmp.iter_mut().zip(x.iter().zip(sw)) { *ti = swi * xi; } - self.design.rmatvec_dt(&tmp, y); + scatter_add(self.design, y, |i| tmp[i]); } } Ok(()) } } + +// --------------------------------------------------------------------------- +// gather/scatter helpers — implementation of DesignOperator's apply/apply_adjoint +// --------------------------------------------------------------------------- +// +// These free functions implement the per-row scatter/gather over a `Design`. +// They live in this module (not in `domain.rs`) because the design itself is +// pure data + layout; these helpers compute the linear map, which is an +// operator concern. + +/// Minimum number of rows before scatter/gather loops are parallelized. 
+const PAR_THRESHOLD: usize = 10_000; + +/// Factor-level threshold for choosing between fold and atomic scatter-add. +/// +/// Factors with fewer than this many levels use thread-local fold/reduce +/// (O(n_levels * n_threads) memory). Larger factors use atomic CAS instead, +/// which has low contention when bins vastly outnumber threads. +/// 100K levels * 8 bytes * ~24 Rayon tasks ~ 19 MB — fits comfortably. +const SCATTER_LOCAL_THRESHOLD: usize = 100_000; + +/// Strategy for a single factor's scatter-add loop. +enum ScatterStrategy { + /// Plain sequential loop — used when n_rows is below `PAR_THRESHOLD`. + Sequential, + /// Parallel fold/reduce with thread-local accumulators — for small factors. + Fold, + /// Parallel atomic CAS — for large factors with low contention. + Atomic, +} + +#[inline] +fn level_from_column_or_store( + store: &S, + levels: Option<&[u32]>, + row: usize, + factor: usize, +) -> usize { + match levels { + Some(col) => col[row] as usize, + None => store.level(row, factor) as usize, + } +} + +/// Pre-compute factor column slices for all factors of `design`. +fn factor_columns(design: &Design) -> Vec> { + design + .factors + .iter() + .enumerate() + .map(|(q, _)| design.store.factor_column(q)) + .collect() +} + +/// Gather-add: `dst[i] += src[offset_q + level(i, q)]` for each factor `q` and row `i`. +/// +/// This is the core loop of `y = D·x`. Loop order is chosen based on the +/// store's preferred iteration pattern for cache locality. +/// +/// For large problems (n_rows > 10 000), rows are partitioned into chunks +/// and processed in parallel via Rayon `par_chunks_mut`. +fn gather_add(design: &Design, src: &[f64], dst: &mut [f64]) { + const CHUNK_SIZE: usize = 4096; + let factor_columns = factor_columns(design); + + if design.n_rows > PAR_THRESHOLD { + // Parallel path: each chunk processes its own row range. 
+ // The inner loop iterates factors inside each chunk, which is optimal for the + // common case (2-3 factors) where all factor data fits in L1 cache. For many + // factors (10+) a layout with factors in the outer loop might help, but + // econometric models typically have 2-5 factors so this isn't worth optimizing. + dst.par_chunks_mut(CHUNK_SIZE) + .enumerate() + .for_each(|(chunk_idx, chunk)| { + let row_start = chunk_idx * CHUNK_SIZE; + for (q, f) in design.factors.iter().enumerate() { + let levels = factor_columns[q]; + for (local, dst_val) in chunk.iter_mut().enumerate() { + let i = row_start + local; + let level = level_from_column_or_store(&design.store, levels, i, q); + *dst_val += src[f.offset + level]; + } + } + }); + } else { + // Sequential factor-major: outer loop on factors, inner on observations. + for (q, f) in design.factors.iter().enumerate() { + let levels = factor_columns[q]; + for (i, dst_i) in dst.iter_mut().enumerate().take(design.n_rows) { + let level = level_from_column_or_store(&design.store, levels, i, q); + *dst_i += src[f.offset + level]; + } + } + } +} + +/// Scatter-add: `dst[offset_q + level(i, q)] += value_fn(i)` for each factor `q` and row `i`. +/// +/// This is the core loop of `x = D^T · r` (and weighted variant `D^T · W · r`). +/// The `value_fn` closure computes the per-row contribution: +/// - unweighted: `|i| r[i]` +/// - weighted: `|i| w[i] * r[i]` +/// +/// For large problems, each factor's row loop is parallelized: +/// - Small factors (< 100K levels): thread-local fold/reduce (avoids CAS contention) +/// - Large factors: atomic CAS scatter (low contention on millions of bins) +/// Factors are processed sequentially so each gets the full thread pool. 
+fn scatter_add( + design: &Design, + dst: &mut [f64], + value_fn: impl Fn(usize) -> f64 + Sync, +) { + let factor_columns = factor_columns(design); + let parallel = design.n_rows > PAR_THRESHOLD; + let max_levels = design.factors.iter().map(|f| f.n_levels).max().unwrap_or(0); + let mut atomic_buf: Vec = Vec::with_capacity(max_levels); + + for (q, f) in design.factors.iter().enumerate() { + let slice = &mut dst[f.offset..f.offset + f.n_levels]; + let levels = factor_columns[q]; + let strategy = if !parallel { + ScatterStrategy::Sequential + } else if f.n_levels < SCATTER_LOCAL_THRESHOLD { + ScatterStrategy::Fold + } else { + ScatterStrategy::Atomic + }; + scatter_add_single_factor( + slice, + design.n_rows, + f.n_levels, + &design.store, + levels, + q, + &value_fn, + strategy, + &mut atomic_buf, + ); + } +} + +/// Scatter-add for a single factor, dispatched by strategy. +/// +/// Accumulates `value_fn(i)` into `slice[level(i, q)]` for all rows, using the +/// requested parallelization strategy. The `atomic_buf` is reused across calls +/// to avoid repeated allocation in the `Atomic` path. +/// +/// All branches share the same per-row `(level, value)` computation via +/// `level_value`; only the accumulation strategy differs. 
+#[allow(clippy::too_many_arguments)] +fn scatter_add_single_factor( + slice: &mut [f64], + n_rows: usize, + n_levels: usize, + store: &S, + levels: Option<&[u32]>, + q: usize, + value_fn: &(impl Fn(usize) -> f64 + Sync), + strategy: ScatterStrategy, + atomic_buf: &mut Vec, +) { + #[inline(always)] + fn level_value( + store: &S, + levels: Option<&[u32]>, + q: usize, + value_fn: &impl Fn(usize) -> f64, + i: usize, + ) -> (usize, f64) { + let level = level_from_column_or_store(store, levels, i, q); + (level, value_fn(i)) + } + + match strategy { + ScatterStrategy::Sequential => { + for i in 0..n_rows { + let (level, val) = level_value(store, levels, q, value_fn, i); + slice[level] += val; + } + } + ScatterStrategy::Fold => { + let min_len = (n_rows / rayon::current_num_threads().max(1)).max(1024); + let result: Vec = (0..n_rows) + .into_par_iter() + .with_min_len(min_len) + .fold( + || vec![0.0f64; n_levels], + |mut acc, i| { + let (level, val) = level_value(store, levels, q, value_fn, i); + acc[level] += val; + acc + }, + ) + .reduce( + || vec![0.0f64; n_levels], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }, + ); + for (d, r) in slice.iter_mut().zip(result.iter()) { + *d += *r; + } + } + ScatterStrategy::Atomic => { + atomic_buf.clear(); + atomic_buf.extend(slice.iter().map(|&v| AtomicF64::new(v))); + (0..n_rows).into_par_iter().for_each(|i| { + let (level, val) = level_value(store, levels, q, value_fn, i); + atomic_buf[level].fetch_add(val, Ordering::Relaxed); + }); + for (d, a) in slice.iter_mut().zip(atomic_buf.iter()) { + *d = a.load(Ordering::Relaxed); + } + } + } +} diff --git a/crates/within/src/operator/gramian/cross_tab.rs b/crates/within/src/operator/gramian/cross_tab.rs index 531ecf4..12c55d6 100644 --- a/crates/within/src/operator/gramian/cross_tab.rs +++ b/crates/within/src/operator/gramian/cross_tab.rs @@ -444,7 +444,7 @@ fn accumulate_dense_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let 
w = Design::::uid_weight(weights, uid); + let w = weights.map_or(1.0, |w| w[uid]); debug_assert!((cj as usize) < n_q && (ck as usize) < n_r); diag_q[cj as usize] += w; diag_r[ck as usize] += w; @@ -485,7 +485,7 @@ fn accumulate_sparse_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = Design::::uid_weight(weights, uid); + let w = weights.map_or(1.0, |w| w[uid]); diag_q[cj as usize] += w; diag_r[ck as usize] += w; row_counts[cj as usize] += 1; @@ -510,7 +510,7 @@ fn accumulate_sparse_cross_block( if cj == u32::MAX || ck == u32::MAX { continue; } - let w = Design::::uid_weight(weights, uid); + let w = weights.map_or(1.0, |w| w[uid]); let pos = cursor[cj as usize] as usize; bucket_cols[pos] = ck; bucket_vals[pos] = w; diff --git a/crates/within/src/solver.rs b/crates/within/src/solver.rs index e00c91f..44e56a3 100644 --- a/crates/within/src/solver.rs +++ b/crates/within/src/solver.rs @@ -36,7 +36,7 @@ use std::time::Instant; use ndarray::ArrayView2; use rayon::prelude::*; -use schwarz_precond::{lsmr, mlsmr}; +use schwarz_precond::{lsmr, mlsmr, Operator as _}; use crate::operator::DesignOperator; @@ -129,18 +129,24 @@ impl Solver { let time_solve = t_solve_start.elapsed().as_secs_f64(); + // demeaned = y - D x. The bare unweighted `D x` matvec uses an + // unweighted DesignOperator over the same design. + let bare_op = DesignOperator::new(&self.design, None); let mut demeaned = vec![0.0; self.design.n_rows]; - self.design.matvec_d(&r.x, &mut demeaned); + bare_op.apply(&r.x, &mut demeaned)?; for (d, &yi) in demeaned.iter_mut().zip(y.iter()) { *d = yi - *d; } - let w_ref = self.weights.as_deref(); + // Relative normal-equation residual: ||D^T W (y - Dx)|| / ||D^T W y||. + // Compute D^T W v as rect_op.apply_adjoint(W^{1/2} v): apply_adjoint + // delivers D^T W^{1/2} (·), so feeding W^{1/2} v gives D^T W v. 
let mut rhs = vec![0.0; self.design.n_dofs]; - self.design.rmatvec_wdt(w_ref, y, &mut rhs); + rect_op.apply_adjoint(&b, &mut rhs)?; let rhs_norm = norm(&rhs).max(1e-15); + let weighted_demeaned = rect_op.weighted_rhs(&demeaned); let mut residual_dof = vec![0.0; self.design.n_dofs]; - self.design.rmatvec_wdt(w_ref, &demeaned, &mut residual_dof); + rect_op.apply_adjoint(&weighted_demeaned, &mut residual_dof)?; let final_residual = norm(&residual_dof) / rhs_norm; Ok(SolveResult { diff --git a/crates/within/tests/common/orchestrate_helpers.rs b/crates/within/tests/common/orchestrate_helpers.rs index ada62e0..102f23a 100644 --- a/crates/within/tests/common/orchestrate_helpers.rs +++ b/crates/within/tests/common/orchestrate_helpers.rs @@ -1,5 +1,7 @@ #![allow(dead_code)] +use schwarz_precond::Operator as _; +use within::operator::DesignOperator; use within::{Design, FactorMajorStore, SolveResult}; pub fn make_test_design() -> Design { @@ -16,7 +18,9 @@ pub fn make_design(categories: Vec>) -> within::WithinResult) -> Vec { let x_true = vec![1.0; design.n_dofs]; let mut y = vec![0.0; design.n_rows]; - design.matvec_d(&x_true, &mut y); + DesignOperator::new(design, None) + .apply(&x_true, &mut y) + .expect("apply succeeds"); y } diff --git a/crates/within/tests/domain.rs b/crates/within/tests/domain.rs index 0e970e1..d9a36e3 100644 --- a/crates/within/tests/domain.rs +++ b/crates/within/tests/domain.rs @@ -3,9 +3,30 @@ //! exercise partition-of-unity weights and disconnected bipartite structure. 
use proptest::prelude::*; +use schwarz_precond::Operator as _; use within::observation::FactorMajorStore; +use within::operator::DesignOperator; use within::Design; +fn apply_d(dm: &Design, x: &[f64], y: &mut [f64]) { + DesignOperator::new(dm, None) + .apply(x, y) + .expect("apply succeeds"); +} + +fn apply_dt(dm: &Design, x: &[f64], y: &mut [f64]) { + DesignOperator::new(dm, None) + .apply_adjoint(x, y) + .expect("apply_adjoint succeeds"); +} + +fn apply_wdt(dm: &Design, weights: Option<&[f64]>, x: &[f64], y: &mut [f64]) { + // D^T W x = apply_adjoint(W^{1/2} x) for op = W^{1/2} D. + let op = DesignOperator::new(dm, weights); + let wx = op.weighted_rhs(x); + op.apply_adjoint(&wx, y).expect("apply_adjoint succeeds"); +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- @@ -46,10 +67,10 @@ fn test_large_design_adjoint_property_matvec_d_rmatvec_dt() { let r: Vec = (0..n_rows).map(|i| (i as f64 * 0.23 + 2.0).cos()).collect(); let mut dx = vec![0.0f64; n_rows]; - dm.matvec_d(&x, &mut dx); + apply_d(&dm, &x, &mut dx); let mut dtr = vec![0.0f64; n_dofs]; - dm.rmatvec_dt(&r, &mut dtr); + apply_dt(&dm, &r, &mut dtr); let lhs = dot(&dx, &r); let rhs = dot(&x, &dtr); @@ -72,7 +93,7 @@ fn test_large_design_matvec_correctness() { let mut ej = vec![0.0f64; n_dofs]; ej[0] = 1.0; let mut y = vec![0.0f64; n_rows]; - dm.matvec_d(&ej, &mut y); + apply_d(&dm, &ej, &mut y); for (i, &yi) in y.iter().enumerate() { let expected = if i % 50 == 0 { 1.0 } else { 0.0 }; @@ -92,7 +113,7 @@ fn test_large_design_rmatvec_dt_correctness() { let ones = vec![1.0f64; n_rows]; let mut x = vec![0.0f64; n_dofs]; - dm.rmatvec_dt(&ones, &mut x); + apply_dt(&dm, &ones, &mut x); // Each factor has 50 levels, 15,000 obs cycling → each level appears 300 times. let expected_count = (n_rows / 50) as f64; @@ -146,7 +167,7 @@ proptest! 
{ // let mut dx = vec![0.0f64; n_rows]; - dm.matvec_d(&x, &mut dx); + apply_d(&dm, &x, &mut dx); let lhs: f64 = dx .iter() .zip(r.iter()) @@ -156,7 +177,7 @@ proptest! { // let mut wdtr = vec![0.0f64; n_dofs]; - dm.rmatvec_wdt(Some(&weights), &r, &mut wdtr); + apply_wdt(&dm, Some(&weights), &r, &mut wdtr); let rhs: f64 = x.iter().zip(wdtr.iter()).map(|(xi, wi)| xi * wi).sum(); prop_assert!( @@ -195,7 +216,7 @@ fn test_three_factor_design_solve_converges() { // Build y = D·1 so the true normal-equation solution is 1. let x_true = vec![1.0f64; dm.n_dofs]; let mut y = vec![0.0f64; dm.n_rows]; - dm.matvec_d(&x_true, &mut y); + apply_d(&dm, &x_true, &mut y); // Use ndarray array2 as required by the solve() API. let n_factors = 3; @@ -256,7 +277,7 @@ fn test_disconnected_design_larger_converges() { let x_true = vec![1.0f64; dm.n_dofs]; let mut y = vec![0.0f64; dm.n_rows]; - dm.matvec_d(&x_true, &mut y); + apply_d(&dm, &x_true, &mut y); let params = SolverParams { tol: 1e-8, @@ -335,10 +356,10 @@ fn test_single_factor_design_adjoint_property() { let r: Vec = vec![0.5, 1.5, -0.5, 2.0, -1.0]; let mut dx = vec![0.0f64; n_rows]; - dm.matvec_d(&x, &mut dx); + apply_d(&dm, &x, &mut dx); let mut dtr = vec![0.0f64; n_dofs]; - dm.rmatvec_dt(&r, &mut dtr); + apply_dt(&dm, &r, &mut dtr); let lhs = dot(&dx, &r); let rhs = dot(&x, &dtr); @@ -392,7 +413,7 @@ fn test_single_factor_matvec_d_values() { let x = vec![10.0, 20.0, 30.0]; let mut y = vec![0.0f64; 5]; - dm.matvec_d(&x, &mut y); + apply_d(&dm, &x, &mut y); assert_eq!(y, vec![10.0, 20.0, 30.0, 10.0, 20.0]); } @@ -405,6 +426,6 @@ fn test_single_factor_rmatvec_dt_values() { let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let mut x = vec![0.0f64; 3]; - dm.rmatvec_dt(&r, &mut x); + apply_dt(&dm, &r, &mut x); assert_eq!(x, vec![5.0, 7.0, 3.0]); } diff --git a/crates/within/tests/properties.rs b/crates/within/tests/properties.rs index 60e079c..a69be55 100644 --- a/crates/within/tests/properties.rs +++ b/crates/within/tests/properties.rs @@ -2,6 
+2,7 @@ use ndarray::Array2; use proptest::prelude::*; use schwarz_precond::Operator; use within::observation::ArrayStore; +use within::operator::DesignOperator; use within::{solve, Design, FePreconditioner, Preconditioner, SolverParams}; /// Generate a random fixed-effects problem as (categories Array2, y Vec). @@ -79,7 +80,9 @@ proptest! { let x_true: Vec = (0..n_dofs).map(|i| (i as f64 * 0.4).sin()).collect(); let mut y = vec![0.0; n_obs]; - design.matvec_d(&x_true, &mut y); + DesignOperator::new(&design, None) + .apply(&x_true, &mut y) + .expect("apply succeeds"); // Use slightly relaxed tolerance — randomly generated problems can be // borderline at 1e-8 (e.g. residual 1.02e-8 after 13 iters). diff --git a/crates/within/tests/property_gaps.rs b/crates/within/tests/property_gaps.rs index 2a089b7..188826d 100644 --- a/crates/within/tests/property_gaps.rs +++ b/crates/within/tests/property_gaps.rs @@ -1,6 +1,8 @@ use ndarray::Array2; use proptest::prelude::*; +use schwarz_precond::Operator as _; use within::observation::ArrayStore; +use within::operator::DesignOperator; use within::{solve, Design, Preconditioner, Solver, SolverParams}; #[path = "common/orchestrate_helpers.rs"] @@ -109,7 +111,9 @@ proptest! { let y_feasible: Vec = { let x_true = vec![1.0; design.n_dofs]; let mut y_out = vec![0.0; design.n_rows]; - design.matvec_d(&x_true, &mut y_out); + DesignOperator::new(&design, None) + .apply(&x_true, &mut y_out) + .expect("apply succeeds"); y_out }; // Use y_feasible so convergence is guaranteed on a consistent system @@ -212,7 +216,9 @@ proptest! { let n_levels = design.n_dofs; let x_true = vec![1.0; n_levels]; let mut y_feasible = vec![0.0; design.n_rows]; - design.matvec_d(&x_true, &mut y_feasible); + DesignOperator::new(&design, None) + .apply(&x_true, &mut y_feasible) + .expect("apply succeeds"); // No preconditioner. 
let params = SolverParams { From 4f1410fd6f44f7ceb1bdb013454c2b12038fdb7d Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 11 May 2026 11:43:43 +0200 Subject: [PATCH 4/6] DesignOperator: drop scratch Mutex, fuse weighted finalize into last gather Adopt the cleaner apply/apply_adjoint pattern explored on the closed refactor/store-design-operator branch: - Replace gather_add + scatter_add with gather_apply + scatter_apply, both taking a closure parameter. gather_apply folds the optional W^{1/2} multiply into the LAST factor's pass, so apply does exactly Q sweeps over dst regardless of weights (no trailing scale loop). scatter_apply computes sw[i] * x[i] inline through a closure, so apply_adjoint allocates no per-call scratch and DesignOperator no longer holds a Mutex<Vec<f64>>. - Add length assertion in DesignOperator::new: panics on mismatched weights length. Solver entry points keep their fallible validate_weights check, so callers going through Solver / solve() never trigger the panic. - ScatterStrategy::pick replaces the ad-hoc if/else chain in scatter_add. - Rename a few stale test/bench labels that still said matvec_d / rmatvec_dt: test_matvec_d -> test_apply_unweighted_values; test_rmatvec_dt -> test_apply_adjoint_unweighted_values; test_large_design_adjoint_property_matvec_d_rmatvec_dt -> drop the method-name suffix; matvec_weighted_design bench group -> design_operator_apply. 
--- crates/within/benches/fixest.rs | 2 +- crates/within/src/operator.rs | 398 ++++++++++++++-------------- crates/within/src/operator/tests.rs | 4 +- crates/within/tests/domain.rs | 4 +- 4 files changed, 201 insertions(+), 207 deletions(-) diff --git a/crates/within/benches/fixest.rs b/crates/within/benches/fixest.rs index 2841b82..dbeb9e7 100644 --- a/crates/within/benches/fixest.rs +++ b/crates/within/benches/fixest.rs @@ -273,7 +273,7 @@ fn matvec_cases() -> [Case; 4] { } fn bench_matvec(c: &mut Criterion) { - let mut group = configure_group(c, "matvec_weighted_design", 50, 200); + let mut group = configure_group(c, "design_operator_apply", 50, 200); for case in matvec_cases() { let label = case.label(); let (design, _y) = generate_fixest_like_case(case, 42); diff --git a/crates/within/src/operator.rs b/crates/within/src/operator.rs index a43e959..cf0b322 100644 --- a/crates/within/src/operator.rs +++ b/crates/within/src/operator.rs @@ -12,7 +12,7 @@ //! //! | Operator | Type | Description | //! |---|---|---| -//! | **sqrt(W) D** | [`DesignOperator`] | Rectangular operator used by LSMR. Implements `sqrt(W) D x` and `D^T sqrt(W) x` via gather/scatter on the observation store | +//! | **D** / **W^{1/2} D** | [`DesignOperator`] | Rectangular `D x` (or `W^{1/2} D x`); pass weights to [`DesignOperator::new`] for the weighted variant | //! //! # Submodules //! @@ -39,12 +39,7 @@ pub(crate) mod schwarz; #[cfg(test)] mod tests; -// --------------------------------------------------------------------------- -// DesignOperator — rectangular, W^{1/2}·D·x / D^T·W^{1/2}·x -// --------------------------------------------------------------------------- - use std::sync::atomic::Ordering; -use std::sync::Mutex; use portable_atomic::AtomicF64; use rayon::prelude::*; @@ -53,96 +48,9 @@ use schwarz_precond::Operator; use crate::domain::Design; use crate::observation::Store; -/// Weighted rectangular design operator: `A = W^{1/2} D`. 
-/// -/// `apply` = `W^{1/2} D x` (observation space), `apply_adjoint` = `D^T W^{1/2} x` (DOF space). -/// For unweighted designs, delegates directly to `D x` / `D^T x` with no extra work. -/// -/// The normal equations of this operator give `A^T A = D^T W D = G` (the Gramian), -/// so the existing Schwarz preconditioner approximating `G^{-1}` can be used directly. -pub struct DesignOperator<'a, S: Store> { - design: &'a Design, - /// Pre-computed `sqrt(w_i)` per observation. `None` when unweighted. - sqrt_weights: Option>, - /// Scratch for the adjoint path: stores `sqrt(w_i) * u_i`. - scratch: Mutex>, -} - -impl<'a, S: Store> DesignOperator<'a, S> { - /// Create from a design matrix and optional observation weights. - /// - /// `weights = None` selects the unweighted fast-path: `sqrt_weights` is - /// `None` and `apply` / `apply_adjoint` skip the per-row scaling entirely. - pub fn new(design: &'a Design, weights: Option<&[f64]>) -> Self { - let sqrt_weights = weights.map(|w| w.iter().map(|wi| wi.sqrt()).collect::>()); - Self { - scratch: Mutex::new(vec![0.0; design.n_rows]), - design, - sqrt_weights, - } - } - - /// Compute the observation-space RHS `b = W^{1/2} y`. - /// - /// For unweighted designs, returns a copy of `y`. 
- pub fn weighted_rhs(&self, y: &[f64]) -> Vec { - match &self.sqrt_weights { - None => y.to_vec(), - Some(sw) => y.iter().zip(sw).map(|(&yi, &swi)| swi * yi).collect(), - } - } -} - -impl Operator for DesignOperator<'_, S> { - fn nrows(&self) -> usize { - self.design.n_rows - } - - fn ncols(&self) -> usize { - self.design.n_dofs - } - - fn apply(&self, x: &[f64], y: &mut [f64]) -> Result<(), schwarz_precond::SolveError> { - debug_assert_eq!(x.len(), self.design.n_dofs); - debug_assert_eq!(y.len(), self.design.n_rows); - // y = W^{1/2} (D x) - y.fill(0.0); - gather_add(self.design, x, y); - if let Some(sw) = &self.sqrt_weights { - for (yi, &swi) in y.iter_mut().zip(sw) { - *yi *= swi; - } - } - Ok(()) - } - - fn apply_adjoint(&self, x: &[f64], y: &mut [f64]) -> Result<(), schwarz_precond::SolveError> { - debug_assert_eq!(x.len(), self.design.n_rows); - debug_assert_eq!(y.len(), self.design.n_dofs); - // y = D^T (W^{1/2} x) - y.fill(0.0); - match &self.sqrt_weights { - None => scatter_add(self.design, y, |i| x[i]), - Some(sw) => { - let mut tmp = self.scratch.lock().unwrap(); - for (ti, (&xi, &swi)) in tmp.iter_mut().zip(x.iter().zip(sw)) { - *ti = swi * xi; - } - scatter_add(self.design, y, |i| tmp[i]); - } - } - Ok(()) - } -} - -// --------------------------------------------------------------------------- -// gather/scatter helpers — implementation of DesignOperator's apply/apply_adjoint -// --------------------------------------------------------------------------- -// -// These free functions implement the per-row scatter/gather over a `Design`. -// They live in this module (not in `domain.rs`) because the design itself is -// pure data + layout; these helpers compute the linear map, which is an -// operator concern. 
+// =========================================================================== +// Iteration kernels — module-private, shared between apply / apply_adjoint +// =========================================================================== /// Minimum number of rows before scatter/gather loops are parallelized. const PAR_THRESHOLD: usize = 10_000; @@ -152,7 +60,6 @@ const PAR_THRESHOLD: usize = 10_000; /// Factors with fewer than this many levels use thread-local fold/reduce /// (O(n_levels * n_threads) memory). Larger factors use atomic CAS instead, /// which has low contention when bins vastly outnumber threads. -/// 100K levels * 8 bytes * ~24 Rayon tasks ~ 19 MB — fits comfortably. const SCATTER_LOCAL_THRESHOLD: usize = 100_000; /// Strategy for a single factor's scatter-add loop. @@ -165,87 +72,105 @@ enum ScatterStrategy { Atomic, } +impl ScatterStrategy { + /// Pick the scatter strategy for one factor. + fn pick(parallel: bool, n_levels: usize) -> Self { + match (parallel, n_levels < SCATTER_LOCAL_THRESHOLD) { + (false, _) => ScatterStrategy::Sequential, + (true, true) => ScatterStrategy::Fold, + (true, false) => ScatterStrategy::Atomic, + } + } +} + +/// Resolve the level for row `i` in factor `q`. +/// +/// `levels` is the optional fast-path column (a contiguous `&[u32]` view of the +/// factor's levels); when `None`, fall back to the store's virtual lookup. +/// Hoisted out of inner loops so the compiler keeps the row body branch-free. #[inline] -fn level_from_column_or_store( - store: &S, - levels: Option<&[u32]>, - row: usize, - factor: usize, -) -> usize { +fn level_at(store: &S, levels: Option<&[u32]>, i: usize, q: usize) -> usize { match levels { - Some(col) => col[row] as usize, - None => store.level(row, factor) as usize, + Some(col) => col[i] as usize, + None => store.level(i, q) as usize, } } -/// Pre-compute factor column slices for all factors of `design`. +/// Pre-compute the factor-column fast-path slices for all factors of `design`. 
fn factor_columns(design: &Design) -> Vec> { - design - .factors - .iter() - .enumerate() - .map(|(q, _)| design.store.factor_column(q)) + (0..design.factors.len()) + .map(|q| design.store.factor_column(q)) .collect() } -/// Gather-add: `dst[i] += src[offset_q + level(i, q)]` for each factor `q` and row `i`. -/// -/// This is the core loop of `y = D·x`. Loop order is chosen based on the -/// store's preferred iteration pattern for cache locality. +/// Gather-apply: `dst[i] = finalize(i, Σ_q src[off_q + level(i, q)])`. /// -/// For large problems (n_rows > 10 000), rows are partitioned into chunks -/// and processed in parallel via Rayon `par_chunks_mut`. -fn gather_add(design: &Design, src: &[f64], dst: &mut [f64]) { - const CHUNK_SIZE: usize = 4096; +/// `finalize` is folded into the LAST factor's pass — exactly Q sweeps over +/// `dst`, no trailing scale loop. The identity finalize (`|_, s| s`) recovers +/// the unweighted gather. +fn gather_apply(design: &Design, src: &[f64], dst: &mut [f64], finalize: F) +where + S: Store, + F: Fn(usize, f64) -> f64 + Sync, +{ + debug_assert_eq!(src.len(), design.n_dofs); + debug_assert_eq!(dst.len(), design.n_rows); + let factors = &design.factors; + if factors.is_empty() { + // Q=0 guard — no factors means dst[i] = finalize(i, 0.0). + for (i, d) in dst.iter_mut().enumerate() { + *d = finalize(i, 0.0); + } + return; + } + dst.fill(0.0); let factor_columns = factor_columns(design); + let store = &design.store; + let last = factors.len() - 1; + + let kernel = |chunk: &mut [f64], row_start: usize| { + // Accumulate factors 0..last + for q in 0..last { + let f = &factors[q]; + let levels = factor_columns[q]; + for (local, dst_val) in chunk.iter_mut().enumerate() { + let i = row_start + local; + *dst_val += src[f.offset + level_at(store, levels, i, q)]; + } + } + // Last factor: accumulate AND finalize, single store per row. + // Q=1 is well-defined: this is the only loop that runs. 
+ let f = &factors[last]; + let levels = factor_columns[last]; + for (local, dst_val) in chunk.iter_mut().enumerate() { + let i = row_start + local; + let s = *dst_val + src[f.offset + level_at(store, levels, i, last)]; + *dst_val = finalize(i, s); + } + }; if design.n_rows > PAR_THRESHOLD { - // Parallel path: each chunk processes its own row range. - // The inner loop iterates factors inside each chunk, which is optimal for the - // common case (2-3 factors) where all factor data fits in L1 cache. For many - // factors (10+) a layout with factors in the outer loop might help, but - // econometric models typically have 2-5 factors so this isn't worth optimizing. + const CHUNK_SIZE: usize = 4096; dst.par_chunks_mut(CHUNK_SIZE) .enumerate() - .for_each(|(chunk_idx, chunk)| { - let row_start = chunk_idx * CHUNK_SIZE; - for (q, f) in design.factors.iter().enumerate() { - let levels = factor_columns[q]; - for (local, dst_val) in chunk.iter_mut().enumerate() { - let i = row_start + local; - let level = level_from_column_or_store(&design.store, levels, i, q); - *dst_val += src[f.offset + level]; - } - } - }); + .for_each(|(chunk_idx, chunk)| kernel(chunk, chunk_idx * CHUNK_SIZE)); } else { - // Sequential factor-major: outer loop on factors, inner on observations. - for (q, f) in design.factors.iter().enumerate() { - let levels = factor_columns[q]; - for (i, dst_i) in dst.iter_mut().enumerate().take(design.n_rows) { - let level = level_from_column_or_store(&design.store, levels, i, q); - *dst_i += src[f.offset + level]; - } - } + kernel(dst, 0); } } -/// Scatter-add: `dst[offset_q + level(i, q)] += value_fn(i)` for each factor `q` and row `i`. -/// -/// This is the core loop of `x = D^T · r` (and weighted variant `D^T · W · r`). -/// The `value_fn` closure computes the per-row contribution: -/// - unweighted: `|i| r[i]` -/// - weighted: `|i| w[i] * r[i]` +/// Scatter-apply: `dst[off_q + level(i, q)] += value_fn(i)` for each factor q, row i. 
/// -/// For large problems, each factor's row loop is parallelized: -/// - Small factors (< 100K levels): thread-local fold/reduce (avoids CAS contention) -/// - Large factors: atomic CAS scatter (low contention on millions of bins) -/// Factors are processed sequentially so each gets the full thread pool. -fn scatter_add( - design: &Design, - dst: &mut [f64], - value_fn: impl Fn(usize) -> f64 + Sync, -) { +/// Caller is responsible for any leading `dst.fill(0.0)`. For large problems, +/// each factor's row loop is parallelized: +/// - Small factors (< 100K levels): thread-local fold/reduce +/// - Large factors: atomic CAS scatter +fn scatter_apply(design: &Design, dst: &mut [f64], value_fn: F) +where + S: Store, + F: Fn(usize) -> f64 + Sync, +{ + debug_assert_eq!(dst.len(), design.n_dofs); let factor_columns = factor_columns(design); let parallel = design.n_rows > PAR_THRESHOLD; let max_levels = design.factors.iter().map(|f| f.n_levels).max().unwrap_or(0); @@ -254,17 +179,10 @@ fn scatter_add( for (q, f) in design.factors.iter().enumerate() { let slice = &mut dst[f.offset..f.offset + f.n_levels]; let levels = factor_columns[q]; - let strategy = if !parallel { - ScatterStrategy::Sequential - } else if f.n_levels < SCATTER_LOCAL_THRESHOLD { - ScatterStrategy::Fold - } else { - ScatterStrategy::Atomic - }; + let strategy = ScatterStrategy::pick(parallel, f.n_levels); scatter_add_single_factor( slice, design.n_rows, - f.n_levels, &design.store, levels, q, @@ -280,14 +198,10 @@ fn scatter_add( /// Accumulates `value_fn(i)` into `slice[level(i, q)]` for all rows, using the /// requested parallelization strategy. The `atomic_buf` is reused across calls /// to avoid repeated allocation in the `Atomic` path. -/// -/// All branches share the same per-row `(level, value)` computation via -/// `level_value`; only the accumulation strategy differs. 
#[allow(clippy::too_many_arguments)] fn scatter_add_single_factor( slice: &mut [f64], n_rows: usize, - n_levels: usize, store: &S, levels: Option<&[u32]>, q: usize, @@ -295,47 +209,38 @@ fn scatter_add_single_factor( strategy: ScatterStrategy, atomic_buf: &mut Vec, ) { - #[inline(always)] - fn level_value( - store: &S, - levels: Option<&[u32]>, - q: usize, - value_fn: &impl Fn(usize) -> f64, - i: usize, - ) -> (usize, f64) { - let level = level_from_column_or_store(store, levels, i, q); - (level, value_fn(i)) - } + let n_levels = slice.len(); + let lvl = |i: usize| level_at(store, levels, i, q); match strategy { ScatterStrategy::Sequential => { for i in 0..n_rows { - let (level, val) = level_value(store, levels, q, value_fn, i); - slice[level] += val; + slice[lvl(i)] += value_fn(i); } } ScatterStrategy::Fold => { let min_len = (n_rows / rayon::current_num_threads().max(1)).max(1024); + + let identity = || vec![0.0f64; n_levels]; + + let fold = |mut acc: Vec, i| { + acc[lvl(i)] += value_fn(i); + acc + }; + + let reduction = |mut a: Vec, b: Vec| { + for (ai, bi) in a.iter_mut().zip(b.iter()) { + *ai += *bi; + } + a + }; + let result: Vec = (0..n_rows) .into_par_iter() .with_min_len(min_len) - .fold( - || vec![0.0f64; n_levels], - |mut acc, i| { - let (level, val) = level_value(store, levels, q, value_fn, i); - acc[level] += val; - acc - }, - ) - .reduce( - || vec![0.0f64; n_levels], - |mut a, b| { - for (ai, bi) in a.iter_mut().zip(b.iter()) { - *ai += *bi; - } - a - }, - ); + .fold(identity, fold) + .reduce(identity, reduction); + for (d, r) in slice.iter_mut().zip(result.iter()) { *d += *r; } @@ -344,8 +249,7 @@ fn scatter_add_single_factor( atomic_buf.clear(); atomic_buf.extend(slice.iter().map(|&v| AtomicF64::new(v))); (0..n_rows).into_par_iter().for_each(|i| { - let (level, val) = level_value(store, levels, q, value_fn, i); - atomic_buf[level].fetch_add(val, Ordering::Relaxed); + atomic_buf[lvl(i)].fetch_add(value_fn(i), Ordering::Relaxed); }); for (d, a) in 
slice.iter_mut().zip(atomic_buf.iter()) { *d = a.load(Ordering::Relaxed); @@ -353,3 +257,93 @@ fn scatter_add_single_factor( } } } + +// =========================================================================== +// DesignOperator — D, optionally rescaled by W^{1/2} +// =========================================================================== + +/// Rectangular design operator: `D` (unweighted) or `W^{1/2} D` (weighted). +/// +/// `apply` = `D x` / `W^{1/2} D x` (gather), `apply_adjoint` = `D^T x` / +/// `D^T W^{1/2} x` (scatter). For the weighted variant, the normal equations +/// `A^T A = D^T W D = G` recover the Gramian, so the same Schwarz +/// preconditioner approximating `G^{-1}` applies. Pass `None` to +/// [`DesignOperator::new`] for `D`, or `Some(&w)` for `W^{1/2} D`. The branch +/// on weights is hoisted outside the per-row loop — the weighted finalize is +/// fused into the last gather sweep, and the adjoint multiplies inline through +/// a closure, so there is no scratch buffer. +pub struct DesignOperator<'a, S: Store> { + design: &'a Design, + sqrt_weights: Option>, +} + +impl<'a, S: Store> DesignOperator<'a, S> { + /// Wrap a design matrix as a linear operator. + /// + /// Pass `None` for `D`, `Some(&w)` for `W^{1/2} D` (then `w.len()` must + /// equal `design.n_rows`). Precomputes and stores `sqrt(W)` when weights + /// are present. + /// + /// # Panics + /// + /// Panics when `weights.is_some()` and `weights.unwrap().len()` does not + /// equal `design.n_rows`. The `Solver` entry points perform fallible + /// validation against `WithinError::WeightCountMismatch` before + /// construction, so callers that go through `Solver::from_design` or + /// `solve()` never trigger this panic. 
+ pub fn new(design: &'a Design, weights: Option<&[f64]>) -> Self { + let sqrt_weights = weights.map(|w| { + assert_eq!( + w.len(), + design.n_rows, + "weights length {} does not match design.n_rows {}", + w.len(), + design.n_rows + ); + w.iter().map(|wi| wi.sqrt()).collect() + }); + Self { + design, + sqrt_weights, + } + } + + /// Compute the observation-space RHS `b = W^{1/2} y`. + /// + /// For unweighted designs, returns a copy of `y`. + pub fn weighted_rhs(&self, y: &[f64]) -> Vec { + match &self.sqrt_weights { + None => y.to_vec(), + Some(sw) => y.iter().zip(sw).map(|(&yi, &swi)| swi * yi).collect(), + } + } +} + +impl Operator for DesignOperator<'_, S> { + fn nrows(&self) -> usize { + self.design.n_rows + } + + fn ncols(&self) -> usize { + self.design.n_dofs + } + + fn apply(&self, x: &[f64], y: &mut [f64]) -> Result<(), schwarz_precond::SolveError> { + match &self.sqrt_weights { + Some(sw) => gather_apply(self.design, x, y, |i, s| sw[i] * s), + None => gather_apply(self.design, x, y, |_, s| s), + } + Ok(()) + } + + fn apply_adjoint(&self, x: &[f64], y: &mut [f64]) -> Result<(), schwarz_precond::SolveError> { + debug_assert_eq!(x.len(), self.design.n_rows); + debug_assert_eq!(y.len(), self.design.n_dofs); + y.fill(0.0); + match &self.sqrt_weights { + Some(sw) => scatter_apply(self.design, y, |i| sw[i] * x[i]), + None => scatter_apply(self.design, y, |i| x[i]), + } + Ok(()) + } +} diff --git a/crates/within/src/operator/tests.rs b/crates/within/src/operator/tests.rs index 7de0844..793083f 100644 --- a/crates/within/src/operator/tests.rs +++ b/crates/within/src/operator/tests.rs @@ -42,7 +42,7 @@ mod design_tests { } #[test] - fn test_matvec_d() { + fn test_apply_unweighted_values() { let schema = make_test_design(); let op = DesignOperator::new(&schema, None); let x = vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0]; @@ -52,7 +52,7 @@ mod design_tests { } #[test] - fn test_rmatvec_dt() { + fn test_apply_adjoint_unweighted_values() { let schema = 
make_test_design(); let op = DesignOperator::new(&schema, None); let r = vec![1.0, 2.0, 3.0, 4.0, 5.0]; diff --git a/crates/within/tests/domain.rs b/crates/within/tests/domain.rs index d9a36e3..6e7ef7d 100644 --- a/crates/within/tests/domain.rs +++ b/crates/within/tests/domain.rs @@ -57,7 +57,7 @@ fn make_large_design() -> Design { } #[test] -fn test_large_design_adjoint_property_matvec_d_rmatvec_dt() { +fn test_large_design_adjoint_property() { // Verify == for random-looking deterministic vectors. let dm = make_large_design(); let n_dofs = dm.n_dofs; @@ -105,7 +105,7 @@ fn test_large_design_matvec_correctness() { } #[test] -fn test_large_design_rmatvec_dt_correctness() { +fn test_large_design_apply_adjoint_correctness() { // D^T·1 should equal the per-level observation count for each factor. let dm = make_large_design(); let n_dofs = dm.n_dofs; From 75573ab5b02efb0170d98cb01f70045d96795ef9 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 11 May 2026 12:17:32 +0200 Subject: [PATCH 5/6] build_preconditioner validates weight length; clean stale doc strings build_preconditioner is a public entry point that takes Option<&[f64]> weights without going through Solver::from_design. Direct callers could pass mismatched-length weights and trigger an index-out-of-bounds panic deep inside accumulate_cross_block. Add validate_weights at the top so the failure surfaces as the usual WithinError::WeightCountMismatch. Drop residual matvec_d / rmatvec_dt / rmatvec_wdt references in tests/domain.rs doc comments and test names. 
--- crates/within/src/operator/preconditioner.rs | 3 ++- crates/within/tests/domain.rs | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/crates/within/src/operator/preconditioner.rs b/crates/within/src/operator/preconditioner.rs index a7c00bc..36dd8d2 100644 --- a/crates/within/src/operator/preconditioner.rs +++ b/crates/within/src/operator/preconditioner.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use crate::config::Preconditioner; use crate::domain::Design; -use crate::observation::Store; +use crate::observation::{validate_weights, Store}; use crate::operator::schwarz::{build_additive_with_strategy, FeSchwarz}; use crate::WithinResult; @@ -102,6 +102,7 @@ pub fn build_preconditioner( ) -> WithinResult { use crate::domain::build_local_domains; + validate_weights(weights, design.n_rows)?; match config { Preconditioner::Additive(local, reduction) => { let domains = build_local_domains(design, weights); diff --git a/crates/within/tests/domain.rs b/crates/within/tests/domain.rs index 6e7ef7d..99aa911 100644 --- a/crates/within/tests/domain.rs +++ b/crates/within/tests/domain.rs @@ -99,7 +99,7 @@ fn test_large_design_matvec_correctness() { let expected = if i % 50 == 0 { 1.0 } else { 0.0 }; assert_eq!( yi, expected, - "matvec_d(e_0)[{i}]: expected {expected}, got {yi}" + "D·e_0 at row {i}: expected {expected}, got {yi}" ); } } @@ -133,7 +133,8 @@ proptest! { #![proptest_config(ProptestConfig::with_cases(10))] /// The adjoint property must hold for random designs: - /// == (i.e., == ) + /// == , with D^T·W·r computed via + /// DesignOperator::apply_adjoint(W^{1/2} r) = D^T W^{1/2} (W^{1/2} r) = D^T W r. 
#[test] fn prop_weighted_adjoint_property( n_obs in 20usize..=200, @@ -405,7 +406,7 @@ fn test_single_factor_design_solve_without_precond() { } #[test] -fn test_single_factor_matvec_d_values() { +fn test_single_factor_apply_values() { // D·[a, b, c] with levels [0,1,2,0,1] should give [a, b, c, a, b] let categories = vec![vec![0u32, 1, 2, 0, 1]]; let store = FactorMajorStore::new(categories, 5).expect("valid store"); @@ -418,7 +419,7 @@ fn test_single_factor_matvec_d_values() { } #[test] -fn test_single_factor_rmatvec_dt_values() { +fn test_single_factor_apply_adjoint_values() { // D^T·[1,2,3,4,5] with levels [0,1,2,0,1] should give [1+4, 2+5, 3] = [5, 7, 3] let categories = vec![vec![0u32, 1, 2, 0, 1]]; let store = FactorMajorStore::new(categories, 5).expect("valid store"); From ca1547f6cce9f5990a0901006b6116d50e3727a0 Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 11 May 2026 12:49:11 +0200 Subject: [PATCH 6/6] changelog: note DesignOperator scratch removal and build_preconditioner validation Two behavioral additions from the codex-review follow-ups (4f1410f, 75573ab) weren't captured in the #28 changelog entries: - DesignOperator::new now asserts weight length matches design.n_rows, the Mutex> scratch field is removed, and weighted apply fuses the sqrt(W) multiply into the last gather pass. - build_preconditioner now returns WithinError::WeightCountMismatch instead of panicking on out-of-bounds access in CrossTab assembly when weights are mis-sized. --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45f38c4..c82abe0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,14 @@ Modified LSMR is now the sole iterative solver, replacing CG and GMRES. `rmatvec_dt`, `rmatvec_wdt`, `gramian_diagonal`, and `uid_weight` methods are removed — use `DesignOperator::new(&design, weights).apply` / `apply_adjoint` instead. 
+- `DesignOperator::new` validates that `weights.len() == design.n_rows` + and panics on mismatch; the scratch `Mutex<Vec<f64>>` field is gone and + weighted `apply` / `apply_adjoint` no longer allocate. The weighted + `apply` fuses the `W^{1/2}` multiply into the last gather pass, so + there is no trailing scale loop. +- `build_preconditioner` now returns `WithinError::WeightCountMismatch` + for wrong-length weights instead of panicking on out-of-bounds access + inside `CrossTab` assembly. ### Removed