From 0d2917dd280398085e47e3ef414bea3e855359fe Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 1 Oct 2025 15:25:52 +0100 Subject: [PATCH 01/12] failing nlvs hessian test --- tests/firedrake/adjoint/test_hessian.py | 72 ++++++++++++------------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/tests/firedrake/adjoint/test_hessian.py b/tests/firedrake/adjoint/test_hessian.py index 294680a554..072090b5ea 100644 --- a/tests/firedrake/adjoint/test_hessian.py +++ b/tests/firedrake/adjoint/test_hessian.py @@ -238,24 +238,24 @@ def test_dirichlet(rg): def test_burgers(solve_type, rg): tape = Tape() set_working_tape(tape) - n = 100 - mesh = UnitIntervalMesh(n) - V = FunctionSpace(mesh, "CG", 2) + nx = 50 + nt = 5 + mesh = UnitIntervalMesh(nx) + V = FunctionSpace(mesh, "CG", 1) - def Dt(u, u_, timestep): - return (u - u_)/timestep + def Dt(u, u_, dt): + return (u - u_)/dt x, = SpatialCoordinate(mesh) - pr = project(sin(2*pi*x), V, annotate=False) - ic = Function(V).assign(pr) + ic = Function(V).project(sin(2*pi*x)) - u_ = Function(V) - u = Function(V) + u_ = Function(V).assign(ic) + u = Function(V).assign(ic) v = TestFunction(V) - nu = Constant(0.0001) + nu = Constant(1/100) - timestep = Constant(1.0/n) + dt = Constant(1/nx) params = { 'snes_rtol': 1e-10, @@ -263,10 +263,9 @@ def Dt(u, u_, timestep): 'pc_type': 'lu', } - F = (Dt(u, ic, timestep)*v + F = (Dt(u, u_, dt)*v + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx bc = DirichletBC(V, 0.0, "on_boundary") - t = 0.0 if solve_type == "nlvs": use_nlvs = True @@ -280,39 +279,36 @@ def Dt(u, u_, timestep): NonlinearVariationalProblem(F, u), solver_parameters=params) - if use_nlvs: - solver.solve() - else: - solve(F == 0, u, bc, solver_parameters=params) - u_.assign(u) - t += float(timestep) - - F = (Dt(u, u_, timestep)*v - + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx - - end = 0.2 - while (t <= end): + for _ in range(nt): if use_nlvs: solver.solve() else: solve(F == 0, u, bc, solver_parameters=params) u_.assign(u) - t 
+= float(timestep) - J = assemble(u_*u_*dx + ic*ic*dx) Jhat = ReducedFunctional(J, Control(ic)) + h = rg.uniform(V) g = ic.copy(deepcopy=True) - J.block_variable.adj_value = 1.0 - ic.block_variable.tlm_value = h - tape.evaluate_adj() - tape.evaluate_tlm() - - J.block_variable.hessian_value = 0 - tape.evaluate_hessian() - - dJdm = J.block_variable.tlm_value - Hm = ic.block_variable.hessian_value.dat.inner(h.dat) - assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 + print(f"{norm(h) = }") + + taylor = taylor_to_dict(Jhat, g, h) + from pprint import pprint + pprint(taylor) + assert min(taylor['R0']['Rate']) > 0.95, taylor['R0'] + assert min(taylor['R1']['Rate']) > 1.95, taylor['R1'] + assert min(taylor['R2']['Rate']) > 2.95, taylor['R2'] + + # J.block_variable.adj_value = 1.0 + # ic.block_variable.tlm_value = h + # tape.evaluate_adj() + # tape.evaluate_tlm() + + # J.block_variable.hessian_value = 0 + # tape.evaluate_hessian() + + # dJdm = J.block_variable.tlm_value + # Hm = ic.block_variable.hessian_value.dat.inner(h.dat) + # assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 From 36f41dfcaaebf93c1ba4c2374e18616a27b97104 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 3 Oct 2025 10:07:38 +0100 Subject: [PATCH 02/12] cached nlvs block - recompute with coefficients --- firedrake/adjoint_utils/blocks/__init__.py | 2 +- firedrake/adjoint_utils/blocks/solving.py | 93 +++++++++++++++++-- firedrake/adjoint_utils/variational_solver.py | 78 +++++++++++++++- tests/firedrake/adjoint/test_nlvs.py | 90 ++++++++++++++++++ 4 files changed, 254 insertions(+), 9 deletions(-) create mode 100644 tests/firedrake/adjoint/test_nlvs.py diff --git a/firedrake/adjoint_utils/blocks/__init__.py b/firedrake/adjoint_utils/blocks/__init__.py index bf83b896cc..d293d4b5d9 100644 --- a/firedrake/adjoint_utils/blocks/__init__.py +++ b/firedrake/adjoint_utils/blocks/__init__.py @@ -1,5 +1,5 @@ from .assembly import AssembleBlock # NOQA F401 -from .solving import GenericSolveBlock, 
SolveLinearSystemBlock, \ +from .solving import CachedSolverBlock, GenericSolveBlock, SolveLinearSystemBlock, \ ProjectBlock, SupermeshProjectBlock, SolveVarFormBlock, \ NonlinearVariationalSolveBlock # NOQA F401 from .function import FunctionAssignBlock, FunctionMergeBlock, \ diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 2c1fb4f876..48f3162d29 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -25,10 +25,89 @@ def extract_subfunction(u, V): return u -class Solver(Enum): +class SolverType(Enum): """Enum for solver types.""" FORWARD = 0 ADJOINT = 1 + TLM = 2 + HESSIAN = 3 + +FORWARD = SolverType.FORWARD +ADJOINT = SolverType.ADJOINT +TLM = SolverType.TLM +HESSIAN = SolverType.HESSIAN + + +# @singledispatch((Coefficient, Constant, Cofunction)) +def update_dependency(replaced_dep, dep): + if not isinstance(replaced_dep, (firedrake.Coefficient, + firedrake.Constant, + firedrake.Cofunction)): + raise TypeError("Updating non-Function-y things not implemented yet.") + replaced_dep.assign(dep.saved_output) + + +class CachedSolverBlock(Block): + def __init__(self, func, cached_solvers, replaced_dependencies, ad_block_tag=None): + super().__init__(ad_block_tag=ad_block_tag) + + self.func = func + self.cached_solvers = cached_solvers + self.replaced_dependencies = replaced_dependencies + + def update_dependencies(self): + for replaced_dep, dep in zip(self.replaced_dependencies, + self.get_dependencies()): + update_dependency(replaced_dep, dep) + + def prepare_recompute_component(self, inputs, relevant_outputs): + return None + + def recompute_component(self, inputs, block_variable, idx, prepared): + self.update_dependencies() + + solver = self.cached_solvers[FORWARD] + solver.solve() + result = solver._problem.u.copy(deepcopy=True) + + if isinstance(block_variable.checkpoint, firedrake.Function): + result = block_variable.checkpoint.assign(result) + + return 
maybe_disk_checkpoint(result) + + def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): + return None + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): + self.update_dependencies() + self.update_tlm_dependencies() + + self.tlm_dFdm.zero() + for i, block_variable in enumerate(self.get_dependencies()): + if block_variable.tlm_value is None: + continue + if block_variable.output == self.func: + continue + + self.tlm_dFdm += self.tlm_dFdm_assemblers[i]() + + solver = self.cached_solvers[TLM] + solver._problem.u.zero() + solver.solve() + result = solver._problem.u.copy(deepcopy=True) + return result + + def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): + pass + + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): + pass + + def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): + pass + + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): + pass class GenericSolveBlock(Block): @@ -656,12 +735,12 @@ def _adjoint_solve(self, dJdu, compute_bdy): and self._ad_solvers["update_adjoint"] ): # Update left hand side of the adjoint equation. - self._ad_solver_replace_forms(Solver.ADJOINT) + self._ad_solver_replace_forms(SolverType.ADJOINT) self._ad_solvers["adjoint_lvs"].invalidate_jacobian() self._ad_solvers["update_adjoint"] = False elif not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian: # Update left hand side of the adjoint equation. - self._ad_solver_replace_forms(Solver.ADJOINT) + self._ad_solver_replace_forms(SolverType.ADJOINT) # Update the right hand side of the adjoint equation. # problem.F._component[1] is the right hand side of the adjoint. 
@@ -679,7 +758,7 @@ def _adjoint_solve(self, dJdu, compute_bdy): return u_sol, adj_sol_bdy def _ad_assign_map(self, form, solver): - if solver == Solver.FORWARD: + if solver == SolverType.FORWARD: count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map else: count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map @@ -697,7 +776,7 @@ def _ad_assign_map(self, form, solver): block_variable.saved_output if ( - solver == Solver.ADJOINT + solver == SolverType.ADJOINT and not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian ): block_variable = self.get_outputs()[0] @@ -712,8 +791,8 @@ def _ad_assign_coefficients(self, form, solver): for coeff, value in assign_map.items(): coeff.assign(value) - def _ad_solver_replace_forms(self, solver=Solver.FORWARD): - if solver == Solver.FORWARD: + def _ad_solver_replace_forms(self, solver=SolverType.FORWARD): + if solver == SolverType.FORWARD: problem = self._ad_solvers["forward_nlvs"]._problem self._ad_assign_coefficients(problem.F, solver) self._ad_assign_coefficients(problem.J, solver) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index d1b0af22ca..d6e4f740da 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -1,7 +1,7 @@ import copy from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations -from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock +from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock, CachedSolverBlock from firedrake.ufl_expr import derivative, adjoint from ufl import replace @@ -52,10 +52,86 @@ def wrapper(self, problem, *args, **kwargs): "recompute_count": 0} self._ad_adj_cache = {} + self._ad_solver_cache = {} + return wrapper + def _ad_cache_forward_solver(self): + from firedrake import ( + Function, Cofunction, + NonlinearVariationalProblem, + 
NonlinearVariationalSolver) + from firedrake.adjoint_utils.blocks.solving import FORWARD + + problem = self._ad_problem + + F = problem.F + replace_map = {} + for old_coeff in F.coefficients(): + if isinstance(old_coeff, Function) and old_coeff.ufl_element().family() == "Real": + new_coeff = copy.deepcopy(old_coeff) + else: + new_coeff = old_coeff.copy(deepcopy=True) + replace_map[old_coeff] = new_coeff + + Fnew = replace(F, replace_map) + unew = replace_map[problem.u] + + for cnew in replace_map.values(): + assert cnew in Fnew.coefficients() + for cold in replace_map.keys(): + assert cold not in Fnew.coefficients() + + # assert False + + nlvp = NonlinearVariationalProblem(Fnew, unew) + nlvs = NonlinearVariationalSolver(nlvp) + + self.dependencies_to_add = tuple(replace_map.keys()) + self.replaced_dependencies = tuple(replace_map.values()) + self._ad_solver_cache[FORWARD] = nlvs + @staticmethod def _ad_annotate_solve(solve): + @wraps(solve) + def wrapper(self, **kwargs): + """To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the + Firedrake solve call. 
This is useful in cases where the solve is known to be irrelevant or diagnostic + for the purposes of the adjoint computation (such as projecting fields to other function spaces + for the purposes of visualisation).""" + from firedrake import LinearVariationalSolver + from firedrake.adjoint_utils.blocks.solving import FORWARD + annotate = annotate_tape(kwargs) + if annotate: + if kwargs.pop("bounds", None) is not None: + raise ValueError( + "MissingMathsError: we do not know how to differentiate through a variational inequality") + + if FORWARD not in self._ad_solver_cache: + self._ad_cache_forward_solver() + + block = CachedSolverBlock(self._ad_problem.u + self._ad_solver_cache, + self.replaced_dependencies, + ad_block_tag=self.ad_block_tag) + + for dep in self.dependencies_to_add: + block.add_dependency(dep, no_duplicates=True) + + get_working_tape().add_block(block) + + with stop_annotating(): + out = solve(self, **kwargs) + + if annotate: + block.add_output(self._ad_problem._ad_u.create_block_variable()) + + return out + + return wrapper + + @staticmethod + def _ad_annotate_solve_old(solve): @wraps(solve) def wrapper(self, **kwargs): """To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py new file mode 100644 index 0000000000..c4d8acbb48 --- /dev/null +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -0,0 +1,90 @@ +import pytest + +from firedrake import * +from firedrake.adjoint import * + + +@pytest.fixture(autouse=True) +def handle_taping(): + yield + tape = get_working_tape() + tape.clear_tape() + + +@pytest.fixture(autouse=True, scope="module") +def handle_annotation(): + if not annotate_tape(): + continue_annotation() + yield + # Ensure annotation is paused when we finish. 
+ if annotate_tape(): + pause_annotation() + + +def forward(ic, dt, nt): + """Burgers equation solver.""" + V = ic.function_space() + R = dt.function_space() + + nu = Function(R).assign(0.1) + u0 = Function(V) + u1 = Function(V) + v = TestFunction(V) + + F = ((u1 - u0)*v + + dt*u1*u1.dx(0)*v + + dt*nu*u1.dx(0)*v.dx(0))*dx + + problem = NonlinearVariationalProblem(F, u1) + solver = NonlinearVariationalSolver(problem) + + u1.assign(ic) + + for i in range(nt): + u0.assign(u1) + solver.solve() + nu += dt + + J = assemble(u1*u1*dx) + return J + + +@pytest.mark.skipcomplex +def test_nlvs_adjoint(): + mesh = UnitIntervalMesh(8) + x, = SpatialCoordinate(mesh) + + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + + nt = 2 + dt = Function(R).assign(0.1) + ic = Function(V).interpolate(cos(2*pi*x)) + + ctype = 'ic' + + if ctype == 'ic': + control = ic + elif ctype == 'dt': + control = dt + else: + raise ValueError + + continue_annotation() + with set_working_tape() as tape: + J = forward(ic, dt, nt) + Jhat = ReducedFunctional(J, Control(control), tape=tape) + pause_annotation() + + if ctype == 'ic': + m = Function(V).assign(0.5*ic) + h = Function(V).interpolate(-0.5*cos(4*pi*x)) + assert abs(Jhat(m) - forward(m, dt, nt)) < 1e-14 + + elif ctype == 'dt': + m = Function(R).assign(0.05) + h = Function(R).assign(0.01) + assert abs(Jhat(m) - forward(ic, m, nt)) < 1e-14 + +if __name__ == "__main__": + test_nlvs_adjoint() From 8876a06b0ec0df38042ca3fcb859f0b7ca968105 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 3 Oct 2025 17:03:05 +0100 Subject: [PATCH 03/12] cached nlvs block - evaluate_tlm with coefficients --- firedrake/adjoint_utils/blocks/solving.py | 60 ++++++++++------ firedrake/adjoint_utils/variational_solver.py | 72 ++++++++++++++----- tests/firedrake/adjoint/test_nlvs.py | 41 +++++++---- 3 files changed, 121 insertions(+), 52 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py 
index 48f3162d29..88543b0e79 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -32,6 +32,7 @@ class SolverType(Enum): TLM = 2 HESSIAN = 3 + FORWARD = SolverType.FORWARD ADJOINT = SolverType.ADJOINT TLM = SolverType.TLM @@ -44,27 +45,46 @@ def update_dependency(replaced_dep, dep): firedrake.Constant, firedrake.Cofunction)): raise TypeError("Updating non-Function-y things not implemented yet.") - replaced_dep.assign(dep.saved_output) + replaced_dep.assign(dep) class CachedSolverBlock(Block): - def __init__(self, func, cached_solvers, replaced_dependencies, ad_block_tag=None): + def __init__(self, func, cached_solvers, + replaced_dependencies, + tlm_rhs, replaced_tlms, tlm_dFdm_forms, + ad_block_tag=None): super().__init__(ad_block_tag=ad_block_tag) self.func = func self.cached_solvers = cached_solvers self.replaced_dependencies = replaced_dependencies - def update_dependencies(self): + self.tlm_rhs = tlm_rhs + self.replaced_tlms = replaced_tlms + self.tlm_dFdm_forms = tlm_dFdm_forms + + def update_dependencies(self, use_output=False): for replaced_dep, dep in zip(self.replaced_dependencies, self.get_dependencies()): - update_dependency(replaced_dep, dep) + update_dependency(replaced_dep, dep.saved_output) + if use_output: + output = self.get_outputs()[0].saved_output + self.cached_solvers[FORWARD]._problem.u.assign(output) + + def update_tlm_dependencies(self): + for replaced_tlm, dep in zip(self.replaced_tlms, + self.get_dependencies()): + if dep.output == self.func: + continue + if dep.tlm_value is None: + continue + update_dependency(replaced_tlm, dep.tlm_value) def prepare_recompute_component(self, inputs, relevant_outputs): return None def recompute_component(self, inputs, block_variable, idx, prepared): - self.update_dependencies() + self.update_dependencies(use_output=False) solver = self.cached_solvers[FORWARD] solver.solve() @@ -79,17 +99,16 @@ def prepare_evaluate_tlm(self, inputs, tlm_inputs, 
relevant_outputs): return None def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): - self.update_dependencies() + self.update_dependencies(use_output=True) self.update_tlm_dependencies() - self.tlm_dFdm.zero() - for i, block_variable in enumerate(self.get_dependencies()): - if block_variable.tlm_value is None: + self.tlm_rhs.zero() + for dFdm, dep in zip(self.tlm_dFdm_forms, self.get_dependencies()): + if dep.tlm_value is None: continue - if block_variable.output == self.func: + if dep.output is self.func: continue - - self.tlm_dFdm += self.tlm_dFdm_assemblers[i]() + self.tlm_rhs += firedrake.assemble(dFdm) solver = self.cached_solvers[TLM] solver._problem.u.zero() @@ -352,6 +371,10 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, return dFdm def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): + pass + + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, + prepared=None): fwd_block_variable = self.get_outputs()[0] u = fwd_block_variable.output @@ -363,16 +386,6 @@ def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): fwd_block_variable.saved_output, firedrake.TrialFunction(u.function_space()) ) - - return { - "form": F_form, - "dFdu": dFdu - } - - def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, - prepared=None): - F_form = prepared["form"] - dFdu = prepared["dFdu"] V = self.get_outputs()[idx].output.function_space() bcs = [] @@ -409,10 +422,11 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, dFdm = ufl.algorithms.expand_derivatives(dFdm) dFdm = firedrake.assemble(dFdm) dudm = firedrake.Function(V) - return self._assemble_and_solve_tlm_eq( + result = self._assemble_and_solve_tlm_eq( firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs), dFdm, dudm, bcs ) + return result def _assemble_and_solve_tlm_eq(self, dFdu, dFdm, dudm, bcs): return self._assembled_solve(dFdu, dFdm, dudm, bcs) diff 
--git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index d6e4f740da..b287e7da43 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -3,7 +3,7 @@ from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock, CachedSolverBlock from firedrake.ufl_expr import derivative, adjoint -from ufl import replace +import ufl class NonlinearVariationalProblemMixin: @@ -58,8 +58,7 @@ def wrapper(self, problem, *args, **kwargs): def _ad_cache_forward_solver(self): from firedrake import ( - Function, Cofunction, - NonlinearVariationalProblem, + Function, NonlinearVariationalProblem, NonlinearVariationalSolver) from firedrake.adjoint_utils.blocks.solving import FORWARD @@ -74,7 +73,7 @@ def _ad_cache_forward_solver(self): new_coeff = old_coeff.copy(deepcopy=True) replace_map[old_coeff] = new_coeff - Fnew = replace(F, replace_map) + Fnew = ufl.replace(F, replace_map) unew = replace_map[problem.u] for cnew in replace_map.values(): @@ -82,15 +81,52 @@ def _ad_cache_forward_solver(self): for cold in replace_map.keys(): assert cold not in Fnew.coefficients() - # assert False - nlvp = NonlinearVariationalProblem(Fnew, unew) nlvs = NonlinearVariationalSolver(nlvp) - self.dependencies_to_add = tuple(replace_map.keys()) - self.replaced_dependencies = tuple(replace_map.values()) + self._ad_dependencies_to_add = tuple(replace_map.keys()) + self._ad_replaced_dependencies = tuple(replace_map.values()) self._ad_solver_cache[FORWARD] = nlvs + def _ad_cache_tlm_solver(self): + from firedrake import ( + Function, Cofunction, derivative, TrialFunction, + LinearVariationalProblem, LinearVariationalSolver) + from firedrake.adjoint_utils.blocks.solving import FORWARD, TLM + + nlvp = self._ad_solver_cache[FORWARD]._problem + + F = nlvp.F + u = nlvp.u + V = u.function_space() + + dFdu = 
derivative(F, u, TrialFunction(V)) + dFdm = Cofunction(V.dual()) + dudm = Function(V) + + lvp = LinearVariationalProblem(dFdu, dFdm, dudm) + lvs = LinearVariationalSolver(lvp) + + self._ad_solver_cache[TLM] = lvs + self._ad_tlm_rhs = dFdm + + replaced_tlms = [] + dFdm_tlm_forms = [] + for m in self._ad_replaced_dependencies: + if isinstance(m, Function) and m.ufl_element().family() == "Real": + mtlm = copy.deepcopy(m) + else: + mtlm = m.copy(deepcopy=True) + + replaced_tlms.append(mtlm) + + dFdm = derivative(-F, m, mtlm) + dFdm = ufl.algorithms.expand_derivatives(dFdm) + dFdm_tlm_forms.append(dFdm) + + self._ad_tlm_dFdm_forms = dFdm_tlm_forms + self._ad_replaced_tlms = replaced_tlms + @staticmethod def _ad_annotate_solve(solve): @wraps(solve) @@ -99,23 +135,25 @@ def wrapper(self, **kwargs): Firedrake solve call. This is useful in cases where the solve is known to be irrelevant or diagnostic for the purposes of the adjoint computation (such as projecting fields to other function spaces for the purposes of visualisation).""" - from firedrake import LinearVariationalSolver - from firedrake.adjoint_utils.blocks.solving import FORWARD annotate = annotate_tape(kwargs) if annotate: if kwargs.pop("bounds", None) is not None: raise ValueError( "MissingMathsError: we do not know how to differentiate through a variational inequality") - if FORWARD not in self._ad_solver_cache: + if len(self._ad_solver_cache) == 0: self._ad_cache_forward_solver() + self._ad_cache_tlm_solver() - block = CachedSolverBlock(self._ad_problem.u + block = CachedSolverBlock(self._ad_problem.u, self._ad_solver_cache, - self.replaced_dependencies, + self._ad_replaced_dependencies, + self._ad_tlm_rhs, + self._ad_replaced_tlms, + self._ad_tlm_dFdm_forms, ad_block_tag=self.ad_block_tag) - for dep in self.dependencies_to_add: + for dep in self._ad_dependencies_to_add: block.add_dependency(dep, no_duplicates=True) get_working_tape().add_block(block) @@ -201,10 +239,10 @@ def _ad_problem_clone(self, problem, 
dependencies): from firedrake import NonlinearVariationalProblem _ad_count_map, J_replace_map, F_replace_map = self._build_count_map( problem.J, dependencies, F=problem.F) - nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map), + nlvp = NonlinearVariationalProblem(ufl.replace(problem.F, F_replace_map), F_replace_map[problem.u_restrict], bcs=problem.bcs, - J=replace(problem.J, J_replace_map)) + J=ufl.replace(problem.J, J_replace_map)) nlvp.is_linear = problem.is_linear nlvp._constant_jacobian = problem._constant_jacobian nlvp._ad_count_map_update(_ad_count_map) @@ -229,7 +267,7 @@ def _ad_adj_lvs_problem(self, block, adj_F): _ad_count_map, J_replace_map, _ = self._build_count_map( adj_F, block._dependencies) lvp = LinearVariationalProblem( - replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, + ufl.replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, bcs=tmp_problem.bcs, constant_jacobian=self._ad_problem._constant_jacobian) lvp._ad_count_map_update(_ad_count_map) diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py index c4d8acbb48..2f789af5c6 100644 --- a/tests/firedrake/adjoint/test_nlvs.py +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -24,9 +24,12 @@ def handle_annotation(): def forward(ic, dt, nt): """Burgers equation solver.""" V = ic.function_space() - R = dt.function_space() - nu = Function(R).assign(0.1) + if isinstance(dt, Constant): + nu = Constant(0.1) + else: + nu = Function(dt.function_space()).assign(0.1) + u0 = Function(V) u1 = Function(V) v = TestFunction(V) @@ -43,29 +46,30 @@ def forward(ic, dt, nt): for i in range(nt): u0.assign(u1) solver.solve() - nu += dt + if not isinstance(nu, Constant): + nu += dt J = assemble(u1*u1*dx) return J @pytest.mark.skipcomplex -def test_nlvs_adjoint(): +@pytest.mark.parametrize("control_type", ["ic_control", "dt_control"]) +def test_nlvs_adjoint(control_type): mesh = UnitIntervalMesh(8) x, = SpatialCoordinate(mesh) V = FunctionSpace(mesh, 
"CG", 1) R = FunctionSpace(mesh, "R", 0) - nt = 2 + nt = 4 dt = Function(R).assign(0.1) + # dt = Constant(0.1) ic = Function(V).interpolate(cos(2*pi*x)) - ctype = 'ic' - - if ctype == 'ic': + if control_type == 'ic_control': control = ic - elif ctype == 'dt': + elif control_type == 'dt_control': control = dt else: raise ValueError @@ -76,15 +80,28 @@ def test_nlvs_adjoint(): Jhat = ReducedFunctional(J, Control(control), tape=tape) pause_annotation() - if ctype == 'ic': + if control_type == 'ic_control': m = Function(V).assign(0.5*ic) h = Function(V).interpolate(-0.5*cos(4*pi*x)) + + # recompute component assert abs(Jhat(m) - forward(m, dt, nt)) < 1e-14 - elif ctype == 'dt': + # tlm + assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + + elif control_type == 'dt_control': m = Function(R).assign(0.05) h = Function(R).assign(0.01) + + # recompute component assert abs(Jhat(m) - forward(ic, m, nt)) < 1e-14 + # tlm + assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + + if __name__ == "__main__": - test_nlvs_adjoint() + ctype = "ic" + print(f"Control type: {ctype}") + test_nlvs_adjoint(f"{ctype}_control") From 2f9cf1a9541b48753acdd2afd3533f48d3f961d1 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 3 Oct 2025 19:58:03 +0100 Subject: [PATCH 04/12] cached nlvs block - recompute and evaluate_tlm with bcs --- firedrake/adjoint_utils/blocks/solving.py | 42 +++++++-- firedrake/adjoint_utils/variational_solver.py | 19 +++- tests/firedrake/adjoint/test_nlvs.py | 87 +++++++++++++------ 3 files changed, 113 insertions(+), 35 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 88543b0e79..ab33a86703 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -41,21 +41,32 @@ class SolverType(Enum): # @singledispatch((Coefficient, Constant, Cofunction)) def update_dependency(replaced_dep, dep): - if not isinstance(replaced_dep, 
(firedrake.Coefficient, - firedrake.Constant, - firedrake.Cofunction)): - raise TypeError("Updating non-Function-y things not implemented yet.") - replaced_dep.assign(dep) + if isinstance(replaced_dep, (firedrake.Coefficient, + firedrake.Constant, + firedrake.Cofunction)): + replaced_dep.assign(dep) + + elif isinstance(replaced_dep, firedrake.DirichletBC): + if dep is None: + replaced_dep.set_value(0) + else: + replaced_dep.set_value(dep.function_arg) + + else: + raise TypeError( + "Updating not implemented for adjoint " + f" dependency of type {type(replaced_dep)}") class CachedSolverBlock(Block): - def __init__(self, func, cached_solvers, + def __init__(self, func, bcs, cached_solvers, replaced_dependencies, tlm_rhs, replaced_tlms, tlm_dFdm_forms, ad_block_tag=None): super().__init__(ad_block_tag=ad_block_tag) self.func = func + self.bcs = bcs self.cached_solvers = cached_solvers self.replaced_dependencies = replaced_dependencies @@ -67,10 +78,19 @@ def update_dependencies(self, use_output=False): for replaced_dep, dep in zip(self.replaced_dependencies, self.get_dependencies()): update_dependency(replaced_dep, dep.saved_output) + if use_output: output = self.get_outputs()[0].saved_output self.cached_solvers[FORWARD]._problem.u.assign(output) + idx = 0 + for dep in self.get_dependencies(): + if isinstance(dep.output, firedrake.DirichletBC): + bc_arg = dep.saved_output.function_arg + self.bcs[idx].set_value(bc_arg) + idx += 1 + assert idx == len(self.bcs) + def update_tlm_dependencies(self): for replaced_tlm, dep in zip(self.replaced_tlms, self.get_dependencies()): @@ -80,6 +100,16 @@ def update_tlm_dependencies(self): continue update_dependency(replaced_tlm, dep.tlm_value) + idx = 0 + for dep in self.get_dependencies(): + if isinstance(dep.output, firedrake.DirichletBC): + if dep.tlm_value is None: + bc_val = 0 + else: + bc_val = dep.tlm_value.function_arg + self.bcs[idx].set_value(bc_val) + idx += 1 + def prepare_recompute_component(self, inputs, 
relevant_outputs): return None diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index b287e7da43..3b478c5ffb 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -58,7 +58,8 @@ def wrapper(self, problem, *args, **kwargs): def _ad_cache_forward_solver(self): from firedrake import ( - Function, NonlinearVariationalProblem, + Function, DirichletBC, + NonlinearVariationalProblem, NonlinearVariationalSolver) from firedrake.adjoint_utils.blocks.solving import FORWARD @@ -81,10 +82,19 @@ def _ad_cache_forward_solver(self): for cold in replace_map.keys(): assert cold not in Fnew.coefficients() - nlvp = NonlinearVariationalProblem(Fnew, unew) + bcs = problem.bcs + bcs_new = [ + DirichletBC(V=bc.function_space(), + g=bc.function_arg, + sub_domain=bc.sub_domain) + for bc in bcs + ] + + nlvp = NonlinearVariationalProblem(Fnew, unew, bcs=bcs_new) nlvs = NonlinearVariationalSolver(nlvp) - self._ad_dependencies_to_add = tuple(replace_map.keys()) + self._ad_bcs = bcs_new + self._ad_dependencies_to_add = tuple((*replace_map.keys(), *bcs)) self._ad_replaced_dependencies = tuple(replace_map.values()) self._ad_solver_cache[FORWARD] = nlvs @@ -104,7 +114,7 @@ def _ad_cache_tlm_solver(self): dFdm = Cofunction(V.dual()) dudm = Function(V) - lvp = LinearVariationalProblem(dFdu, dFdm, dudm) + lvp = LinearVariationalProblem(dFdu, dFdm, dudm, bcs=self._ad_bcs) lvs = LinearVariationalSolver(lvp) self._ad_solver_cache[TLM] = lvs @@ -146,6 +156,7 @@ def wrapper(self, **kwargs): self._ad_cache_tlm_solver() block = CachedSolverBlock(self._ad_problem.u, + self._ad_bcs, self._ad_solver_cache, self._ad_replaced_dependencies, self._ad_tlm_rhs, diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py index 2f789af5c6..957226a51f 100644 --- a/tests/firedrake/adjoint/test_nlvs.py +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -21,14 +21,18 @@ def 
handle_annotation(): pause_annotation() -def forward(ic, dt, nt): +def forward(ic, dt, nt, bc_arg=None): """Burgers equation solver.""" V = ic.function_space() - if isinstance(dt, Constant): - nu = Constant(0.1) + if bc_arg: + bc_val = bc_arg.copy(deepcopy=True) + bc = DirichletBC(V, bc_val, 1) + # bc.apply(ic) else: - nu = Function(dt.function_space()).assign(0.1) + bc = None + + nu = Function(dt.function_space()).assign(0.1) u0 = Function(V) u1 = Function(V) @@ -38,7 +42,7 @@ def forward(ic, dt, nt): + dt*u1*u1.dx(0)*v + dt*nu*u1.dx(0)*v.dx(0))*dx - problem = NonlinearVariationalProblem(F, u1) + problem = NonlinearVariationalProblem(F, u1, bcs=bc) solver = NonlinearVariationalSolver(problem) u1.assign(ic) @@ -46,37 +50,59 @@ def forward(ic, dt, nt): for i in range(nt): u0.assign(u1) solver.solve() - if not isinstance(nu, Constant): - nu += dt + nu += dt + if bc_arg: + bc_val.assign(bc_val + dt) J = assemble(u1*u1*dx) return J @pytest.mark.skipcomplex -@pytest.mark.parametrize("control_type", ["ic_control", "dt_control"]) -def test_nlvs_adjoint(control_type): - mesh = UnitIntervalMesh(8) +@pytest.mark.parametrize("control_type", ["ic_control", + "dt_control", + "bc_control"]) +@pytest.mark.parametrize("bc_type", ["neumann_bc", + "dirichlet_bc"]) +def test_nlvs_adjoint(control_type, bc_type): + if control_type == 'bc_control' and bc_type == 'neumann_bc': + pytest.skip("Cannot use Neumann BCs as control") + + mesh = UnitIntervalMesh(10) x, = SpatialCoordinate(mesh) V = FunctionSpace(mesh, "CG", 1) R = FunctionSpace(mesh, "R", 0) - nt = 4 + nt = 2 dt = Function(R).assign(0.1) - # dt = Constant(0.1) ic = Function(V).interpolate(cos(2*pi*x)) + dt0 = dt.copy(deepcopy=True) + ic0 = ic.copy(deepcopy=True) + + if bc_type == 'neumann_bc': + bc_arg = None + bc_arg0 = None + elif bc_type == 'dirichlet_bc': + bc_arg = Function(R).assign(1.) 
+ bc_arg0 = bc_arg.copy(deepcopy=True) + else: + raise ValueError(f"Unrecognised {bc_type = }") + if control_type == 'ic_control': - control = ic + control = ic0 elif control_type == 'dt_control': - control = dt + control = dt0 + elif control_type == 'bc_control': + control = bc_arg0 else: - raise ValueError + raise ValueError(f"Unrecognised {control_type = }") + print("record tape") continue_annotation() with set_working_tape() as tape: - J = forward(ic, dt, nt) + J = forward(ic0, dt0, nt, bc_arg=bc_arg0) Jhat = ReducedFunctional(J, Control(control), tape=tape) pause_annotation() @@ -84,21 +110,32 @@ def test_nlvs_adjoint(control_type): m = Function(V).assign(0.5*ic) h = Function(V).interpolate(-0.5*cos(4*pi*x)) - # recompute component - assert abs(Jhat(m) - forward(m, dt, nt)) < 1e-14 - - # tlm - assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + ic2 = m.copy(deepcopy=True) + dt2 = dt + bc_arg2 = bc_arg elif control_type == 'dt_control': m = Function(R).assign(0.05) h = Function(R).assign(0.01) - # recompute component - assert abs(Jhat(m) - forward(ic, m, nt)) < 1e-14 + ic2 = ic + dt2 = m.copy(deepcopy=True) + bc_arg2 = bc_arg + + elif control_type == 'bc_control': + m = Function(R).assign(0.5) + h = Function(R).assign(-0.1) + + ic2 = ic + dt2 = dt + bc_arg2 = m.copy(deepcopy=True) + + # recompute component + print("recompute test") + assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 - # tlm - assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + # tlm + assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 if __name__ == "__main__": From 0d987df581e99abdcdf62e1c7b6ec927c1806bd5 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Mon, 6 Oct 2025 17:59:03 +0100 Subject: [PATCH 05/12] cached nlvs block - evaluate adj with coefficient or bc controls --- firedrake/adjoint_utils/blocks/solving.py | 199 +++++++++++------- firedrake/adjoint_utils/variational_solver.py | 141 ++++++++++++- tests/firedrake/adjoint/test_hessian.py | 13 -- 
tests/firedrake/adjoint/test_nlvs.py | 26 ++- 4 files changed, 274 insertions(+), 105 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index ab33a86703..a8723c0fce 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -39,29 +39,11 @@ class SolverType(Enum): HESSIAN = SolverType.HESSIAN -# @singledispatch((Coefficient, Constant, Cofunction)) -def update_dependency(replaced_dep, dep): - if isinstance(replaced_dep, (firedrake.Coefficient, - firedrake.Constant, - firedrake.Cofunction)): - replaced_dep.assign(dep) - - elif isinstance(replaced_dep, firedrake.DirichletBC): - if dep is None: - replaced_dep.set_value(0) - else: - replaced_dep.set_value(dep.function_arg) - - else: - raise TypeError( - "Updating not implemented for adjoint " - f" dependency of type {type(replaced_dep)}") - - class CachedSolverBlock(Block): def __init__(self, func, bcs, cached_solvers, replaced_dependencies, tlm_rhs, replaced_tlms, tlm_dFdm_forms, + adj_rhs, replaced_adjs, adj_dFdm_forms, adj_residual, ad_block_tag=None): super().__init__(ad_block_tag=ad_block_tag) @@ -74,44 +56,72 @@ def __init__(self, func, bcs, cached_solvers, self.replaced_tlms = replaced_tlms self.tlm_dFdm_forms = tlm_dFdm_forms + self.adj_rhs = adj_rhs + self.replaced_adjs = replaced_adjs + self.adj_dFdm_forms = adj_dFdm_forms + self.adj_residual = adj_residual + + def _coefficient_dependencies(self, dependencies=None): + dependencies = dependencies or self.get_dependencies() + return dependencies[:len(self.replaced_dependencies)] + + def _bc_dependencies(self, dependencies=None): + dependencies = dependencies or self.get_dependencies() + if len(self.bcs) > 0: + return dependencies[-len(self.bcs):] + else: + return [] + def update_dependencies(self, use_output=False): + """Update all dependencies of the forward solve. + """ + # Update the coefficients in the form. 
+ # Use the fact that zip will use the shorter length. for replaced_dep, dep in zip(self.replaced_dependencies, - self.get_dependencies()): - update_dependency(replaced_dep, dep.saved_output) - + self._coefficient_dependencies()): + replaced_dep.assign(dep.saved_output) + + # 1. For forward recomputation the unknown Function should use + # the incoming value of the dependency as the initial guess. + # 2. For the adjoint, TLM, and Hessian, the unknown Function + # should use the computed value so that the linearised + # Jacobian is correct. if use_output: output = self.get_outputs()[0].saved_output self.cached_solvers[FORWARD]._problem.u.assign(output) - idx = 0 - for dep in self.get_dependencies(): - if isinstance(dep.output, firedrake.DirichletBC): - bc_arg = dep.saved_output.function_arg - self.bcs[idx].set_value(bc_arg) - idx += 1 - assert idx == len(self.bcs) + # Update the boundary conditions + for replaced_dep, dep in zip(self.bcs, self._bc_dependencies()): + replaced_dep.set_value(dep.saved_output.function_arg) def update_tlm_dependencies(self): - for replaced_tlm, dep in zip(self.replaced_tlms, - self.get_dependencies()): - if dep.output == self.func: + """Update all dependencies of the tlm solve. 
+ """ + for replaced_dep, dep in zip(self.replaced_tlms, + self._coefficient_dependencies()): + if dep.output == self.func: # TODO: and not self.linear continue - if dep.tlm_value is None: + if dep.tlm_value is None: # This dependency doesn't depend on the controls continue - update_dependency(replaced_tlm, dep.tlm_value) + replaced_dep.assign(dep.tlm_value) - idx = 0 - for dep in self.get_dependencies(): - if isinstance(dep.output, firedrake.DirichletBC): - if dep.tlm_value is None: - bc_val = 0 - else: - bc_val = dep.tlm_value.function_arg - self.bcs[idx].set_value(bc_val) - idx += 1 + for replaced_dep, dep in zip(self.bcs, self._bc_dependencies()): + if dep.tlm_value is None: # This dependency doesn't depend on the controls + bc_val = 0 + else: + bc_val = dep.tlm_value.function_arg + replaced_dep.set_value(bc_val) + + def update_adj_dependencies(self): + # TODO: Anything to do here? + pass + + def _compute_boundary(self, relevant_dependencies): + return any(isinstance(dep.output, firedrake.DirichletBC) + for _, dep in relevant_dependencies) def prepare_recompute_component(self, inputs, relevant_outputs): - return None + return def recompute_component(self, inputs, block_variable, idx, prepared): self.update_dependencies(use_output=False) @@ -120,26 +130,29 @@ def recompute_component(self, inputs, block_variable, idx, prepared): solver.solve() result = solver._problem.u.copy(deepcopy=True) + # Possibly checkpoint the result for the adjoint solve later. 
if isinstance(block_variable.checkpoint, firedrake.Function): result = block_variable.checkpoint.assign(result) return maybe_disk_checkpoint(result) def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): - return None + return def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): self.update_dependencies(use_output=True) self.update_tlm_dependencies() + # Assemble the rhs of (dF/du)(du/dm) = -dF/dm self.tlm_rhs.zero() for dFdm, dep in zip(self.tlm_dFdm_forms, self.get_dependencies()): - if dep.tlm_value is None: + if dep.tlm_value is None: # This dependency doesn't depend on the controls continue - if dep.output is self.func: + if dep.output is self.func: # Can't compute dependence on initial guess? continue self.tlm_rhs += firedrake.assemble(dFdm) + # Solve for dudm solver = self.cached_solvers[TLM] solver._problem.u.zero() solver.solve() @@ -147,10 +160,48 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar return result def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - pass + self.update_dependencies(use_output=True) + self.update_adj_dependencies() + + for bc in self.bcs: + bc.homogenize() + + solver = self.cached_solvers[ADJOINT] + adj_sol = solver._problem.u + + dJdu = adj_inputs[0] + self.adj_rhs.assign(dJdu) + adj_sol.zero() + + solver.solve() + + if self._compute_boundary(relevant_dependencies): + adj_sol_bc = firedrake.assemble(self.adj_residual) + adj_sol_bc = adj_sol_bc.riesz_representation("l2") + else: + adj_sol_bc = None + + prepared = { + "adj_sol": adj_sol, + "adj_sol_bc": adj_sol_bc + } + return prepared def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): - pass + if block_variable.output == self.func: # and not self.linear + return None + + if isinstance(block_variable.output, firedrake.DirichletBC): + bc = block_variable.output + adj_sol_bc = prepared["adj_sol_bc"] + return bc.reconstruct( + 
+ g=extract_subfunction(adj_sol_bc, bc.function_space()) + ) + + # assemble sensitivity comment + dFdm = firedrake.assemble(self.adj_dFdm_forms[idx]) + + return dFdm def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): pass @@ -327,31 +378,6 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): r["adj_sol_bdy"] = adj_sol_bdy return r - def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): - return firedrake.assemble(dFdu_adj_form, **kwargs) - - def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): - dJdu_copy = dJdu.copy() - # Homogenize and apply boundary conditions on adj_dFdu. - bcs = self._homogenize_bcs() - dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs) - - adj_sol = firedrake.Function(self.function_space) - firedrake.solve( - dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs - ) - - adj_sol_bdy = None - if compute_bdy: - adj_sol_bdy = self._compute_adj_bdy( - adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy) - - return adj_sol, adj_sol_bdy - - def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): - adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol)) - return adj_sol_bdy.riesz_representation("l2") - def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): if not self.linear and self.func == block_variable.output: @@ -400,6 +426,31 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs) return dFdm + def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): + return firedrake.assemble(dFdu_adj_form, **kwargs) + + def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): + dJdu_copy = dJdu.copy() + # Homogenize and apply boundary conditions on adj_dFdu. 
+ bcs = self._homogenize_bcs() + dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs) + + adj_sol = firedrake.Function(self.function_space) + firedrake.solve( + dFdu, adj_sol, dJdu, *self.adj_args, **self.adj_kwargs + ) + + adj_sol_bdy = None + if compute_bdy: + adj_sol_bdy = self._compute_adj_bdy( + adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy) + + return adj_sol, adj_sol_bdy + + def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): + adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol)) + return adj_sol_bdy.riesz_representation("l2") + def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): pass diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index 3b478c5ffb..0fde2bea3f 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -2,7 +2,7 @@ from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock, CachedSolverBlock -from firedrake.ufl_expr import derivative, adjoint +from firedrake.ufl_expr import derivative, adjoint, action import ufl @@ -65,23 +65,32 @@ def _ad_cache_forward_solver(self): problem = self._ad_problem + # Build a new form so that we can update the coefficient + # values without affecting user code. + # We do this by copying all coefficients in the form and + # symbolically replacing the old values with the new. F = problem.F replace_map = {} for old_coeff in F.coefficients(): if isinstance(old_coeff, Function) and old_coeff.ufl_element().family() == "Real": + # TODO: Does this still need to be special cased? new_coeff = copy.deepcopy(old_coeff) else: new_coeff = old_coeff.copy(deepcopy=True) replace_map[old_coeff] = new_coeff + # We need a handle to the new Function being + # solved for so that we can create an NLVS. 
Fnew = ufl.replace(F, replace_map) unew = replace_map[problem.u] - for cnew in replace_map.values(): - assert cnew in Fnew.coefficients() - for cold in replace_map.keys(): - assert cold not in Fnew.coefficients() - + # We also need to "replace" all the bcs in + # the new NLVS so we can modify those values + # without affecting user code. + # Note that ``DirichletBC.reconstruct`` will + # return ``self`` if V, g, and sub_domain are + # all unchanged, so we need to explicitly + # instantiate a new object. bcs = problem.bcs bcs_new = [ DirichletBC(V=bc.function_space(), @@ -90,40 +99,60 @@ def _ad_cache_forward_solver(self): for bc in bcs ] + # This NLVS will be used to recompute the solve. nlvp = NonlinearVariationalProblem(Fnew, unew, bcs=bcs_new) nlvs = NonlinearVariationalSolver(nlvp) - self._ad_bcs = bcs_new - self._ad_dependencies_to_add = tuple((*replace_map.keys(), *bcs)) + # The original coefficients will be added as + # dependencies to all solve blocks. + # The block needs handles to the newly created + # objects to update their values when recomputing. + self._ad_dependencies_to_add = (*replace_map.keys(), *bcs) self._ad_replaced_dependencies = tuple(replace_map.values()) + self._ad_bcs = bcs_new self._ad_solver_cache[FORWARD] = nlvs + print(f"{replace_map.values()=}") + def _ad_cache_tlm_solver(self): from firedrake import ( Function, Cofunction, derivative, TrialFunction, LinearVariationalProblem, LinearVariationalSolver) from firedrake.adjoint_utils.blocks.solving import FORWARD, TLM + # If we build the TLM form from the cached + # forward solve form then we can update exactly + # the same coefficients/boundary conditions. nlvp = self._ad_solver_cache[FORWARD]._problem F = nlvp.F u = nlvp.u V = u.function_space() + # We need gradient of output/input i.e. du/dm. + # We know F(u; m) = 0 and _total_ dF/dm = 0. 
+ # Then for the _partial_ derivatives: + # (dF/du)*(du/dm) + dF/dm = 0 so we calculate: + # (dF/du)*(du/dm) = -dF/dm dFdu = derivative(F, u, TrialFunction(V)) dFdm = Cofunction(V.dual()) dudm = Function(V) + # Reuse the same bcs as the forward problem. + # TODO: Think about if we should use new bcs. lvp = LinearVariationalProblem(dFdu, dFdm, dudm, bcs=self._ad_bcs) lvs = LinearVariationalSolver(lvp) self._ad_solver_cache[TLM] = lvs self._ad_tlm_rhs = dFdm + # Do all the symbolic work for calculating dF/dm up front + # so we only pay for the numeric calculations at run time. replaced_tlms = [] dFdm_tlm_forms = [] for m in self._ad_replaced_dependencies: if isinstance(m, Function) and m.ufl_element().family() == "Real": + # TODO: Does this still need to be special cased? mtlm = copy.deepcopy(m) else: mtlm = m.copy(deepcopy=True) @@ -131,12 +160,101 @@ def _ad_cache_tlm_solver(self): replaced_tlms.append(mtlm) dFdm = derivative(-F, m, mtlm) + # TODO: Do we need expand_derivatives here? If so, why? dFdm = ufl.algorithms.expand_derivatives(dFdm) dFdm_tlm_forms.append(dFdm) + # We'll need to update the replaced_tlm + # values and assemble the dFdm forms self._ad_tlm_dFdm_forms = dFdm_tlm_forms self._ad_replaced_tlms = replaced_tlms + def _ad_cache_adj_solver(self): + from firedrake import ( + Function, Cofunction, TrialFunction, Argument, + LinearVariationalProblem, LinearVariationalSolver) + from firedrake.adjoint_utils.blocks.solving import FORWARD, ADJOINT + + # If we build the adjoint form from the cached + # forward solve form then we can update exactly + # the same coefficients/boundary conditions. + nlvp = self._ad_solver_cache[FORWARD]._problem + + F = nlvp.F + u = nlvp.u + V = u.function_space() + + # TODO: rewrite for adjoint not TLM + # We need gradient of output/input i.e. du/dm. + # We know F(u; m) = 0 and _total_ dF/dm = 0. 
+ # Then for the _partial_ derivatives: + # (dF/du)*(du/dm) + dF/dm = 0 so we calculate: + # (dF/du)*(du/dm) = -dF/dm + dFdu = derivative(F, u, TrialFunction(V)) + try: + dFdu_adj = adjoint(dFdu) + except ValueError: + # Try again without expanding derivatives, + # as dFdu might have been simplied to an empty Form + dFdu_adj = adjoint(dFdu, derivatives_expanded=True) + + dJdu = Cofunction(V.dual()) + adj_sol = Function(V) + + # Reuse the same bcs as the forward problem. + # TODO: Think about if we should use new bcs. + lvp = LinearVariationalProblem(dFdu_adj, dJdu, adj_sol, bcs=self._ad_bcs) + lvs = LinearVariationalSolver(lvp) + + self._ad_solver_cache[ADJOINT] = lvs + self._ad_adj_rhs = dJdu + + # Do all the symbolic work for calculating dJ/du up front + # so we only pay for the numeric calculations at run time. + replaced_adjs = [] + dFdm_adj_forms = [] + for m in self._ad_replaced_dependencies: + if isinstance(m, Function) and m.ufl_element().family() == "Real": + # TODO: Does this still need to be special cased? + adj = copy.deepcopy(m) + else: + adj = m.copy(deepcopy=True) + + replaced_adjs.append(adj) + + # Action of adjoint solution on dFdm + # TODO: Which of the two implementations should we use? + dFdm = derivative(-F, m, TrialFunction(m.function_space())) + + # 1. from previous cached implementation + dFdm = adjoint(dFdm) + if isinstance(dFdm, Argument): + # Corner case. Should be fixed more permanently upstream in UFL. + # See: https://github.com/FEniCS/ufl/issues/395 + dFdm = ufl.Action(dFdm, adj_sol) + else: + dFdm = dFdm * adj_sol + + # 2. from GenericSolveBlock + # if isinstance(dFdm, ufl.Form): + # dFdm = adjoint(dFdm) + # dFdm = action(dFdm, adj_sol) + # else: + # dFdm = dFdm(adj_sol) + + dFdm_adj_forms.append(dFdm) + + # To calculate the adjoint component of each DirichletBC + # we'll need the residual of the adjoint equation without + # any DirichletBC using the solution calculated with + # homogeneous DirichletBCs. 
+ self._ad_adj_residual = dJdu - action(dFdu_adj, adj_sol) + + # We'll need to update the replaced_adj + # values and assemble the dFdm forms + self._ad_adj_dFdm_forms = dFdm_adj_forms + self._ad_replaced_adjs = replaced_adjs + @staticmethod def _ad_annotate_solve(solve): @wraps(solve) @@ -154,14 +272,21 @@ def wrapper(self, **kwargs): if len(self._ad_solver_cache) == 0: self._ad_cache_forward_solver() self._ad_cache_tlm_solver() + self._ad_cache_adj_solver() block = CachedSolverBlock(self._ad_problem.u, self._ad_bcs, self._ad_solver_cache, self._ad_replaced_dependencies, + self._ad_tlm_rhs, self._ad_replaced_tlms, self._ad_tlm_dFdm_forms, + + self._ad_adj_rhs, + self._ad_replaced_adjs, + self._ad_adj_dFdm_forms, + self._ad_adj_residual, ad_block_tag=self.ad_block_tag) for dep in self._ad_dependencies_to_add: diff --git a/tests/firedrake/adjoint/test_hessian.py b/tests/firedrake/adjoint/test_hessian.py index 072090b5ea..41242e7bd7 100644 --- a/tests/firedrake/adjoint/test_hessian.py +++ b/tests/firedrake/adjoint/test_hessian.py @@ -292,7 +292,6 @@ def Dt(u, u_, dt): h = rg.uniform(V) g = ic.copy(deepcopy=True) - print(f"{norm(h) = }") taylor = taylor_to_dict(Jhat, g, h) from pprint import pprint @@ -300,15 +299,3 @@ def Dt(u, u_, dt): assert min(taylor['R0']['Rate']) > 0.95, taylor['R0'] assert min(taylor['R1']['Rate']) > 1.95, taylor['R1'] assert min(taylor['R2']['Rate']) > 2.95, taylor['R2'] - - # J.block_variable.adj_value = 1.0 - # ic.block_variable.tlm_value = h - # tape.evaluate_adj() - # tape.evaluate_tlm() - - # J.block_variable.hessian_value = 0 - # tape.evaluate_hessian() - - # dJdm = J.block_variable.tlm_value - # Hm = ic.block_variable.hessian_value.dat.inner(h.dat) - # assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.9 diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py index 957226a51f..ea88ef6423 100644 --- a/tests/firedrake/adjoint/test_nlvs.py +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -27,10 +27,10 @@ 
def forward(ic, dt, nt, bc_arg=None): if bc_arg: bc_val = bc_arg.copy(deepcopy=True) - bc = DirichletBC(V, bc_val, 1) - # bc.apply(ic) + bcs = [DirichletBC(V, bc_val, 1), + DirichletBC(V, 0, 2)] else: - bc = None + bcs = None nu = Function(dt.function_space()).assign(0.1) @@ -42,7 +42,7 @@ def forward(ic, dt, nt, bc_arg=None): + dt*u1*u1.dx(0)*v + dt*nu*u1.dx(0)*v.dx(0))*dx - problem = NonlinearVariationalProblem(F, u1, bcs=bc) + problem = NonlinearVariationalProblem(F, u1, bcs=bcs) solver = NonlinearVariationalSolver(problem) u1.assign(ic) @@ -74,7 +74,7 @@ def test_nlvs_adjoint(control_type, bc_type): V = FunctionSpace(mesh, "CG", 1) R = FunctionSpace(mesh, "R", 0) - nt = 2 + nt = 4 dt = Function(R).assign(0.1) ic = Function(V).interpolate(cos(2*pi*x)) @@ -88,7 +88,7 @@ def test_nlvs_adjoint(control_type, bc_type): bc_arg = Function(R).assign(1.) bc_arg0 = bc_arg.copy(deepcopy=True) else: - raise ValueError(f"Unrecognised {bc_type = }") + raise ValueError(f"Unrecognised {bc_type=}") if control_type == 'ic_control': control = ic0 @@ -97,7 +97,7 @@ def test_nlvs_adjoint(control_type, bc_type): elif control_type == 'bc_control': control = bc_arg0 else: - raise ValueError(f"Unrecognised {control_type = }") + raise ValueError(f"Unrecognised {control_type=}") print("record tape") continue_annotation() @@ -135,10 +135,16 @@ def test_nlvs_adjoint(control_type, bc_type): assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 # tlm + Jhat(m) assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + # adjoint + Jhat(m) + assert taylor_test(Jhat, m, h) > 1.95 + if __name__ == "__main__": - ctype = "ic" - print(f"Control type: {ctype}") - test_nlvs_adjoint(f"{ctype}_control") + control_type = "ic_control" + bc_type = "neumann_bc" + print(f"{control_type=} | {bc_type=}") + test_nlvs_adjoint(control_type, bc_type) From 7065b99ff72508c62a5fefa717a67cd6e58becf4 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 9 Oct 2025 13:33:29 +0100 Subject: [PATCH 
06/12] fix adjoint integer reassignment warnings --- tests/firedrake/adjoint/test_solving.py | 18 +++++++++--------- tests/firedrake/adjoint/test_tlm.py | 12 ++++++------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/firedrake/adjoint/test_solving.py b/tests/firedrake/adjoint/test_solving.py index 5ce9b120ce..7997fe39f5 100644 --- a/tests/firedrake/adjoint/test_solving.py +++ b/tests/firedrake/adjoint/test_solving.py @@ -33,7 +33,7 @@ def test_linear_problem(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = TrialFunction(V) u_ = Function(V) @@ -60,7 +60,7 @@ def test_singular_linear_problem(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "CG", 1) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = TrialFunction(V) u_ = Function(V) @@ -85,7 +85,7 @@ def test_nonlinear_problem(pre_apply_bcs, rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = Function(V) v = TestFunction(V) @@ -116,7 +116,7 @@ def test_mixed_boundary(rg): g1 = Constant(2) g2 = Constant(1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(f): a = f*inner(grad(u), grad(v))*dx @@ -165,7 +165,7 @@ def xtest_wrt_function_dirichlet_boundary(): g1 = Constant(2) g2 = Constant(1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(bc): a = inner(grad(u), grad(v))*dx @@ -195,7 +195,7 @@ def test_wrt_function_neumann_boundary(): g1 = Function(R, val=2) g2 = Function(R, val=1) - f = Function(V).assign(10) + f = Function(V).assign(10.) def J(g1): a = inner(grad(u), grad(v))*dx @@ -247,7 +247,7 @@ def test_wrt_constant_neumann_boundary(): g1 = Function(R, val=2) g2 = Function(R, val=1) - f = Function(V).assign(10) + f = Function(V).assign(10.) 
def J(g1): a = inner(grad(u), grad(v))*dx @@ -283,7 +283,7 @@ def test_time_dependent(): f = Function(R, val=1) def J(f): - u_1 = Function(V).assign(1) + u_1 = Function(V).assign(1.) a = u_1*u*v*dx + dt*f*inner(grad(u), grad(v))*dx L = u_1*v*dx @@ -340,7 +340,7 @@ def _test_adjoint_function_boundary(J, bc, f): set_working_tape(tape) V = f.function_space() - h = Function(V).assign(1) + h = Function(V).assign(1.) g = Function(V) eps_ = [0.4/2.0**i for i in range(4)] residuals = [] diff --git a/tests/firedrake/adjoint/test_tlm.py b/tests/firedrake/adjoint/test_tlm.py index 723f84e481..b9123ad9bd 100644 --- a/tests/firedrake/adjoint/test_tlm.py +++ b/tests/firedrake/adjoint/test_tlm.py @@ -36,7 +36,7 @@ def test_tlm_assemble(rg): set_working_tape(tape) mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(5) + f = Function(V).assign(5.) u = TrialFunction(V) v = TestFunction(V) @@ -66,7 +66,7 @@ def test_tlm_bc(): V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) c = Function(R, val=1) - f = Function(V).assign(1) + f = Function(V).assign(1.) u = Function(V) v = TestFunction(V) @@ -88,8 +88,8 @@ def test_tlm_func(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - c = Function(V).assign(1) - f = Function(V).assign(1) + c = Function(V).assign(1.) + f = Function(V).assign(1.) u = Function(V) v = TestFunction(V) @@ -130,9 +130,9 @@ def test_time_dependent(solve_type, rg): # Some variables T = 0.5 dt = 0.1 - f = Function(V).assign(1) + f = Function(V).assign(1.) - u_1 = Function(V).assign(1) + u_1 = Function(V).assign(1.) 
control = Control(u_1) a = u_1 * u * v * dx + dt * f * inner(grad(u), grad(v)) * dx From 52258b412fae7cba6cc066ee327fcc66023c4340 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 9 Oct 2025 13:34:14 +0100 Subject: [PATCH 07/12] remove unused nlvs block adj variables --- firedrake/adjoint_utils/blocks/solving.py | 3 +- firedrake/adjoint_utils/variational_solver.py | 34 +++++-------------- 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index a8723c0fce..7a32817bc2 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -43,7 +43,7 @@ class CachedSolverBlock(Block): def __init__(self, func, bcs, cached_solvers, replaced_dependencies, tlm_rhs, replaced_tlms, tlm_dFdm_forms, - adj_rhs, replaced_adjs, adj_dFdm_forms, adj_residual, + adj_rhs, adj_dFdm_forms, adj_residual, ad_block_tag=None): super().__init__(ad_block_tag=ad_block_tag) @@ -57,7 +57,6 @@ def __init__(self, func, bcs, cached_solvers, self.tlm_dFdm_forms = tlm_dFdm_forms self.adj_rhs = adj_rhs - self.replaced_adjs = replaced_adjs self.adj_dFdm_forms = adj_dFdm_forms self.adj_residual = adj_residual diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index 0fde2bea3f..a8c9bed0fc 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -72,11 +72,7 @@ def _ad_cache_forward_solver(self): F = problem.F replace_map = {} for old_coeff in F.coefficients(): - if isinstance(old_coeff, Function) and old_coeff.ufl_element().family() == "Real": - # TODO: Does this still need to be special cased? 
- new_coeff = copy.deepcopy(old_coeff) - else: - new_coeff = old_coeff.copy(deepcopy=True) + new_coeff = old_coeff.copy(deepcopy=True) replace_map[old_coeff] = new_coeff # We need a handle to the new Function being @@ -100,6 +96,7 @@ def _ad_cache_forward_solver(self): ] # This NLVS will be used to recompute the solve. + # TODO: solver_parameters nlvp = NonlinearVariationalProblem(Fnew, unew, bcs=bcs_new) nlvs = NonlinearVariationalSolver(nlvp) @@ -112,8 +109,6 @@ def _ad_cache_forward_solver(self): self._ad_bcs = bcs_new self._ad_solver_cache[FORWARD] = nlvs - print(f"{replace_map.values()=}") - def _ad_cache_tlm_solver(self): from firedrake import ( Function, Cofunction, derivative, TrialFunction, @@ -140,6 +135,7 @@ def _ad_cache_tlm_solver(self): # Reuse the same bcs as the forward problem. # TODO: Think about if we should use new bcs. + # TODO: solver_parameters lvp = LinearVariationalProblem(dFdu, dFdm, dudm, bcs=self._ad_bcs) lvs = LinearVariationalSolver(lvp) @@ -151,12 +147,7 @@ def _ad_cache_tlm_solver(self): replaced_tlms = [] dFdm_tlm_forms = [] for m in self._ad_replaced_dependencies: - if isinstance(m, Function) and m.ufl_element().family() == "Real": - # TODO: Does this still need to be special cased? - mtlm = copy.deepcopy(m) - else: - mtlm = m.copy(deepcopy=True) - + mtlm = m.copy(deepcopy=True) replaced_tlms.append(mtlm) dFdm = derivative(-F, m, mtlm) @@ -198,11 +189,13 @@ def _ad_cache_adj_solver(self): # as dFdu might have been simplied to an empty Form dFdu_adj = adjoint(dFdu, derivatives_expanded=True) + # This will be the rhs of the adjoint problem dJdu = Cofunction(V.dual()) adj_sol = Function(V) # Reuse the same bcs as the forward problem. # TODO: Think about if we should use new bcs. 
+ # TODO: solver_parameters lvp = LinearVariationalProblem(dFdu_adj, dJdu, adj_sol, bcs=self._ad_bcs) lvs = LinearVariationalSolver(lvp) @@ -211,17 +204,8 @@ def _ad_cache_adj_solver(self): # Do all the symbolic work for calculating dJ/du up front # so we only pay for the numeric calculations at run time. - replaced_adjs = [] dFdm_adj_forms = [] for m in self._ad_replaced_dependencies: - if isinstance(m, Function) and m.ufl_element().family() == "Real": - # TODO: Does this still need to be special cased? - adj = copy.deepcopy(m) - else: - adj = m.copy(deepcopy=True) - - replaced_adjs.append(adj) - # Action of adjoint solution on dFdm # TODO: Which of the two implementations should we use? dFdm = derivative(-F, m, TrialFunction(m.function_space())) @@ -250,10 +234,9 @@ def _ad_cache_adj_solver(self): # homogeneous DirichletBCs. self._ad_adj_residual = dJdu - action(dFdu_adj, adj_sol) - # We'll need to update the replaced_adj - # values and assemble the dFdm forms + # We'll need to assemble these forms to calculate + # the adj_component for each dependency. 
self._ad_adj_dFdm_forms = dFdm_adj_forms - self._ad_replaced_adjs = replaced_adjs @staticmethod def _ad_annotate_solve(solve): @@ -284,7 +267,6 @@ def wrapper(self, **kwargs): self._ad_tlm_dFdm_forms, self._ad_adj_rhs, - self._ad_replaced_adjs, self._ad_adj_dFdm_forms, self._ad_adj_residual, ad_block_tag=self.ad_block_tag) From 6950356e4447376acc409c4a599b1e50e5f6b992 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 10 Oct 2025 12:28:41 +0100 Subject: [PATCH 08/12] cached nlvs block - hessian with ic controls --- firedrake/adjoint_utils/blocks/solving.py | 109 +++++++++++++++- firedrake/adjoint_utils/variational_solver.py | 118 ++++++++++++++++-- tests/firedrake/adjoint/test_nlvs.py | 19 ++- 3 files changed, 227 insertions(+), 19 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 7a32817bc2..9fe16ac8cc 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -44,6 +44,10 @@ def __init__(self, func, bcs, cached_solvers, replaced_dependencies, tlm_rhs, replaced_tlms, tlm_dFdm_forms, adj_rhs, adj_dFdm_forms, adj_residual, + adj_sol, adj2_sol, tlm_output, + d2Fdu2_form, d2Fdmdu_forms, + dFdm_adj2_forms, d2Fdm2_adj_forms, + d2Fdudm_forms, ad_block_tag=None): super().__init__(ad_block_tag=ad_block_tag) @@ -60,6 +64,16 @@ def __init__(self, func, bcs, cached_solvers, self.adj_dFdm_forms = adj_dFdm_forms self.adj_residual = adj_residual + self.adj_sol = adj_sol + self.adj_sol_buf = adj_sol.copy(deepcopy=True) + self.adj2_sol = adj2_sol + self.tlm_output = tlm_output + self.d2Fdu2_form = d2Fdu2_form + self.d2Fdmdu_forms = d2Fdmdu_forms + self.dFdm_adj2_forms = dFdm_adj2_forms + self.d2Fdm2_adj_forms = d2Fdm2_adj_forms + self.d2Fdudm_forms = d2Fdudm_forms + def _coefficient_dependencies(self, dependencies=None): dependencies = dependencies or self.get_dependencies() return dependencies[:len(self.replaced_dependencies)] @@ -115,6 +129,10 @@ def 
update_adj_dependencies(self): # TODO: Anything to do here? pass + def update_hessian_dependencies(self): + # TODO: Anything else to do here? + self.update_tlm_dependencies() + def _compute_boundary(self, relevant_dependencies): return any(isinstance(dep.output, firedrake.DirichletBC) for _, dep in relevant_dependencies) @@ -147,7 +165,7 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar for dFdm, dep in zip(self.tlm_dFdm_forms, self.get_dependencies()): if dep.tlm_value is None: # This dependency doesn't depend on the controls continue - if dep.output is self.func: # Can't compute dependence on initial guess? + if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess continue self.tlm_rhs += firedrake.assemble(dFdm) @@ -174,6 +192,10 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): solver.solve() + # store this for Hessian computation + self.adj_sol.assign(adj_sol) + self.adj_sol_buf.assign(adj_sol) + if self._compute_boundary(relevant_dependencies): adj_sol_bc = firedrake.assemble(self.adj_residual) adj_sol_bc = adj_sol_bc.riesz_representation("l2") @@ -203,10 +225,80 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar return dFdm def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): - pass + self.adj_sol.assign(self.adj_sol_buf) + self.update_dependencies(use_output=True) + self.update_hessian_dependencies() + + hessian_input = hessian_inputs[0] + tlm_output = self.get_outputs()[0].tlm_value + + if hessian_input is None: + return + if tlm_output is None: + return + + # 1. 
Assemble rhs + + # hessian input contribution + self.adj_rhs.assign(hessian_input) + + # tlm_output contribution + self.tlm_output.assign(tlm_output) + self.adj_rhs -= firedrake.assemble(self.d2Fdu2_form) + + # tlm_input contribution + for d2Fdmdu, dep in zip(self.d2Fdmdu_forms, + self._coefficient_dependencies()): + if dep.tlm_value is None: # This dependency doesn't depend on the controls + continue + if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess + continue + if len(d2Fdmdu.integrals()) > 0: + self.adj_rhs -= firedrake.assemble(d2Fdmdu) + + # 2. Solve adjoint system + for bc in self.bcs: + bc.homogenize() + + solver = self.cached_solvers[ADJOINT] + adj2_sol = solver._problem.u + adj2_sol.zero() + solver.solve() + + self.adj2_sol.assign(adj2_sol) + + prepared = { + "adj2_sol": adj2_sol.copy(deepcopy=True), + "adj2_sol_bc": None, + } + + return prepared def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): - pass + m = block_variable.output + + if m is self.func: # and not self.linear + return None + + relevant_d2Fdm2_forms = [] + for i, dep in relevant_dependencies: + if i >= len(self.replaced_dependencies): + continue + if dep.tlm_value is None: + continue + if dep.output is self.func: # and not self.linear + continue + relevant_d2Fdm2_forms.append(self.d2Fdm2_adj_forms[idx][i]) + + hessian_output = 0 + + for form in (self.d2Fdudm_forms[idx], + self.dFdm_adj2_forms[idx], + *relevant_d2Fdm2_forms): + if not form.empty(): + hessian_output = firedrake.assemble(-form) + + return hessian_output class GenericSolveBlock(Block): @@ -538,7 +630,10 @@ def _assemble_soa_eq_rhs(self, dFdu_form, adj_sol, hessian_input, d2Fdu2): elif not isinstance(c, firedrake.DirichletBC): dFdu_adj = firedrake.action(firedrake.adjoint(dFdu_form), adj_sol) - b_form += firedrake.derivative(dFdu_adj, c_rep, tlm_input) + # b_form += firedrake.derivative(dFdu_adj, c_rep, 
tlm_input) + bo_form = ufl.algorithms.expand_derivatives( + firedrake.derivative(dFdu_adj, c_rep, tlm_input)) + b_form += bo_form b_form = ufl.algorithms.expand_derivatives(b_form) if len(b_form.integrals()) > 0: @@ -566,6 +661,8 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, hessian_input = hessian_inputs[0] tlm_output = fwd_block_variable.tlm_value + self.adj_state = self.adj_state_buf.copy(deepcopy=True) + if hessian_input is None: return @@ -594,6 +691,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, r["adj_sol2_bdy"] = adj_sol2_bdy r["form"] = F_form r["adj_sol"] = adj_sol + return r def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, @@ -792,6 +890,8 @@ def __init__(self, equation, func, bcs, adj_cache, problem_J, self.problem_J = problem_J self.solver_kwargs = solver_kwargs + self.adj_state_buf = func.copy(deepcopy=True) + super().__init__(lhs, rhs, func, bcs, **{**solver_kwargs, **kwargs}) if self.problem_J is not None: @@ -900,6 +1000,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): ) adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy) self.adj_state = adj_sol + self.adj_state_buf.assign(adj_sol) if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index a8c9bed0fc..731f0f76b4 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -3,7 +3,8 @@ from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock, CachedSolverBlock from firedrake.ufl_expr import derivative, adjoint, action -import ufl +from ufl import replace, Action +from ufl.algorithms import expand_derivatives class NonlinearVariationalProblemMixin: @@ -58,7 +59,7 
@@ def wrapper(self, problem, *args, **kwargs): def _ad_cache_forward_solver(self): from firedrake import ( - Function, DirichletBC, + DirichletBC, NonlinearVariationalProblem, NonlinearVariationalSolver) from firedrake.adjoint_utils.blocks.solving import FORWARD @@ -77,7 +78,7 @@ def _ad_cache_forward_solver(self): # We need a handle to the new Function being # solved for so that we can create an NLVS. - Fnew = ufl.replace(F, replace_map) + Fnew = replace(F, replace_map) unew = replace_map[problem.u] # We also need to "replace" all the bcs in @@ -133,6 +134,8 @@ def _ad_cache_tlm_solver(self): dFdm = Cofunction(V.dual()) dudm = Function(V) + self._ad_dFdu = dFdu + # Reuse the same bcs as the forward problem. # TODO: Think about if we should use new bcs. # TODO: solver_parameters @@ -152,7 +155,7 @@ def _ad_cache_tlm_solver(self): dFdm = derivative(-F, m, mtlm) # TODO: Do we need expand_derivatives here? If so, why? - dFdm = ufl.algorithms.expand_derivatives(dFdm) + dFdm = expand_derivatives(dFdm) dFdm_tlm_forms.append(dFdm) # We'll need to update the replaced_tlm @@ -181,7 +184,7 @@ def _ad_cache_adj_solver(self): # Then for the _partial_ derivatives: # (dF/du)*(du/dm) + dF/dm = 0 so we calculate: # (dF/du)*(du/dm) = -dF/dm - dFdu = derivative(F, u, TrialFunction(V)) + dFdu = self._ad_dFdu try: dFdu_adj = adjoint(dFdu) except ValueError: @@ -189,6 +192,8 @@ def _ad_cache_adj_solver(self): # as dFdu might have been simplied to an empty Form dFdu_adj = adjoint(dFdu, derivatives_expanded=True) + self._ad_dFdu_adj = dFdu_adj + # This will be the rhs of the adjoint problem dJdu = Cofunction(V.dual()) adj_sol = Function(V) @@ -215,7 +220,7 @@ def _ad_cache_adj_solver(self): if isinstance(dFdm, Argument): # Corner case. Should be fixed more permanently upstream in UFL. 
# See: https://github.com/FEniCS/ufl/issues/395 - dFdm = ufl.Action(dFdm, adj_sol) + dFdm = Action(dFdm, adj_sol) else: dFdm = dFdm * adj_sol @@ -238,8 +243,88 @@ def _ad_cache_adj_solver(self): # the adj_component for each dependency. self._ad_adj_dFdm_forms = dFdm_adj_forms + def _ad_cache_hessian_solver(self): + from firedrake import ( + Function, TestFunction) + from firedrake.adjoint_utils.blocks.solving import FORWARD + + nlvp = self._ad_solver_cache[FORWARD]._problem + F = nlvp.F + u = nlvp.u + V = u.function_space() + + # 1. Forms to calculate rhs of Hessian solve + + # Calculate d^2F/du^2 * du/dm * dm + # where dm is direction for tlm action so du/dm * dm is tlm output + dFdu = self._ad_dFdu + tlm_output = Function(V) + d2Fdu2 = expand_derivatives( + derivative(dFdu, u, tlm_output)) + + self._ad_tlm_output = tlm_output + + adj_sol = Function(V) + self._ad_adj_sol = adj_sol + + # Contribution from tlm_output + if len(d2Fdu2.integrals()) > 0: + d2Fdu2_form = action(adjoint(d2Fdu2), adj_sol) + else: + d2Fdu2_form = d2Fdu2 + self._ad_d2Fdu2_form = d2Fdu2_form + + # Contributions from each tlm_input + dFdu_adj = action(self._ad_dFdu_adj, adj_sol) + d2Fdmdu_forms = [] + for m, dm in zip(self._ad_replaced_dependencies, + self._ad_replaced_tlms): + d2Fdmdu = expand_derivatives( + derivative(dFdu_adj, m, dm)) + + d2Fdmdu_forms.append(d2Fdmdu) + + self._ad_d2Fdmdu_forms = d2Fdmdu_forms + + # 2. 
Forms to calculate contribution from each control + adj2_sol = Function(V) + self._ad_adj2_sol = adj2_sol + + Fadj = action(F, adj_sol) + Fadj2 = action(F, adj2_sol) + + dFdm_adj2_forms = [] + d2Fdm2_adj_forms = [] + d2Fdudm_forms = [] + for m in self._ad_replaced_dependencies: + dm = TestFunction(m.function_space()) + dFdm_adj2 = expand_derivatives( + derivative(Fadj2, m, dm)) + + dFdm_adj2_forms.append(dFdm_adj2) + + dFdm_adj = derivative(Fadj, m, dm) + + d2Fdudm = expand_derivatives( + derivative(dFdm_adj, u, tlm_output)) + + d2Fdudm_forms.append(d2Fdudm) + + d2Fdm2_adj_forms_k = [] + for m2, dm2 in zip(self._ad_replaced_dependencies, + self._ad_replaced_tlms): + d2Fdm2_adj = expand_derivatives( + derivative(dFdm_adj, m2, dm2)) + d2Fdm2_adj_forms_k.append(d2Fdm2_adj) + + d2Fdm2_adj_forms.append(d2Fdm2_adj_forms_k) + + self._ad_dFdm_adj2_forms = dFdm_adj2_forms + self._ad_d2Fdm2_adj_forms = d2Fdm2_adj_forms + self._ad_d2Fdudm_forms = d2Fdudm_forms + @staticmethod - def _ad_annotate_solve(solve): + def _ad_annotate_solve_new(solve): @wraps(solve) def wrapper(self, **kwargs): """To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the @@ -256,6 +341,7 @@ def wrapper(self, **kwargs): self._ad_cache_forward_solver() self._ad_cache_tlm_solver() self._ad_cache_adj_solver() + self._ad_cache_hessian_solver() block = CachedSolverBlock(self._ad_problem.u, self._ad_bcs, @@ -269,6 +355,16 @@ def wrapper(self, **kwargs): self._ad_adj_rhs, self._ad_adj_dFdm_forms, self._ad_adj_residual, + + self._ad_adj_sol, + self._ad_adj2_sol, + self._ad_tlm_output, + self._ad_d2Fdu2_form, + self._ad_d2Fdmdu_forms, + self._ad_dFdm_adj2_forms, + self._ad_d2Fdm2_adj_forms, + self._ad_d2Fdudm_forms, + ad_block_tag=self.ad_block_tag) for dep in self._ad_dependencies_to_add: @@ -287,7 +383,7 @@ def wrapper(self, **kwargs): return wrapper @staticmethod - def _ad_annotate_solve_old(solve): + def _ad_annotate_solve(solve): @wraps(solve) def 
wrapper(self, **kwargs): """To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the @@ -357,10 +453,10 @@ def _ad_problem_clone(self, problem, dependencies): from firedrake import NonlinearVariationalProblem _ad_count_map, J_replace_map, F_replace_map = self._build_count_map( problem.J, dependencies, F=problem.F) - nlvp = NonlinearVariationalProblem(ufl.replace(problem.F, F_replace_map), + nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map), F_replace_map[problem.u_restrict], bcs=problem.bcs, - J=ufl.replace(problem.J, J_replace_map)) + J=replace(problem.J, J_replace_map)) nlvp.is_linear = problem.is_linear nlvp._constant_jacobian = problem._constant_jacobian nlvp._ad_count_map_update(_ad_count_map) @@ -385,7 +481,7 @@ def _ad_adj_lvs_problem(self, block, adj_F): _ad_count_map, J_replace_map, _ = self._build_count_map( adj_F, block._dependencies) lvp = LinearVariationalProblem( - ufl.replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, + replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, bcs=tmp_problem.bcs, constant_jacobian=self._ad_problem._constant_jacobian) lvp._ad_count_map_update(_ad_count_map) diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py index ea88ef6423..eaeef7bc0e 100644 --- a/tests/firedrake/adjoint/test_nlvs.py +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -51,8 +51,8 @@ def forward(ic, dt, nt, bc_arg=None): u0.assign(u1) solver.solve() nu += dt - if bc_arg: - bc_val.assign(bc_val + dt) + # if bc_arg: + # bc_val.assign(bc_val + dt) J = assemble(u1*u1*dx) return J @@ -68,13 +68,13 @@ def test_nlvs_adjoint(control_type, bc_type): if control_type == 'bc_control' and bc_type == 'neumann_bc': pytest.skip("Cannot use Neumann BCs as control") - mesh = UnitIntervalMesh(10) + mesh = UnitIntervalMesh(6) x, = SpatialCoordinate(mesh) V = FunctionSpace(mesh, "CG", 1) R = FunctionSpace(mesh, "R", 0) - nt = 4 + nt = 2 dt = 
Function(R).assign(0.1) ic = Function(V).interpolate(cos(2*pi*x)) @@ -135,13 +135,24 @@ def test_nlvs_adjoint(control_type, bc_type): assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 # tlm + print("tlm test") Jhat(m) assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 # adjoint + print("adjoint test") Jhat(m) assert taylor_test(Jhat, m, h) > 1.95 + # hessian + print("hessian test") + Jhat(m) + taylor = taylor_to_dict(Jhat, m, h) + from pprint import pprint + pprint(taylor) + + assert min(taylor['R2']['Rate']) > 2.95 + if __name__ == "__main__": control_type = "ic_control" From 0809b769141a93eae89f1d3b14ab1f1c4a68eb38 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 10 Oct 2025 12:28:41 +0100 Subject: [PATCH 09/12] cached nlvs block - hessian with ic controls --- firedrake/adjoint_utils/blocks/solving.py | 109 ++++++++++++++++- firedrake/adjoint_utils/variational_solver.py | 114 ++++++++++++++++-- tests/firedrake/adjoint/test_nlvs.py | 15 ++- 3 files changed, 223 insertions(+), 15 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 7a32817bc2..3f34187357 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -44,6 +44,10 @@ def __init__(self, func, bcs, cached_solvers, replaced_dependencies, tlm_rhs, replaced_tlms, tlm_dFdm_forms, adj_rhs, adj_dFdm_forms, adj_residual, + adj_sol, adj2_sol, tlm_output, + d2Fdu2_form, d2Fdmdu_forms, + dFdm_adj2_forms, d2Fdm2_adj_forms, + d2Fdudm_forms, ad_block_tag=None): super().__init__(ad_block_tag=ad_block_tag) @@ -60,6 +64,16 @@ def __init__(self, func, bcs, cached_solvers, self.adj_dFdm_forms = adj_dFdm_forms self.adj_residual = adj_residual + self.adj_sol = adj_sol + self.adj_sol_buf = adj_sol.copy(deepcopy=True) + self.adj2_sol = adj2_sol + self.tlm_output = tlm_output + self.d2Fdu2_form = d2Fdu2_form + self.d2Fdmdu_forms = d2Fdmdu_forms + self.dFdm_adj2_forms = dFdm_adj2_forms 
+ self.d2Fdm2_adj_forms = d2Fdm2_adj_forms + self.d2Fdudm_forms = d2Fdudm_forms + def _coefficient_dependencies(self, dependencies=None): dependencies = dependencies or self.get_dependencies() return dependencies[:len(self.replaced_dependencies)] @@ -115,6 +129,10 @@ def update_adj_dependencies(self): # TODO: Anything to do here? pass + def update_hessian_dependencies(self): + # TODO: Anything else to do here? + self.update_tlm_dependencies() + def _compute_boundary(self, relevant_dependencies): return any(isinstance(dep.output, firedrake.DirichletBC) for _, dep in relevant_dependencies) @@ -147,7 +165,7 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar for dFdm, dep in zip(self.tlm_dFdm_forms, self.get_dependencies()): if dep.tlm_value is None: # This dependency doesn't depend on the controls continue - if dep.output is self.func: # Can't compute dependence on initial guess? + if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess continue self.tlm_rhs += firedrake.assemble(dFdm) @@ -174,6 +192,10 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): solver.solve() + # store this for Hessian computation + self.adj_sol.assign(adj_sol) + self.adj_sol_buf.assign(adj_sol) + if self._compute_boundary(relevant_dependencies): adj_sol_bc = firedrake.assemble(self.adj_residual) adj_sol_bc = adj_sol_bc.riesz_representation("l2") @@ -203,10 +225,80 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepar return dFdm def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): - pass + self.adj_sol.assign(self.adj_sol_buf) + self.update_dependencies(use_output=True) + self.update_hessian_dependencies() + + hessian_input = hessian_inputs[0] + tlm_output = self.get_outputs()[0].tlm_value + + if hessian_input is None: + return + if tlm_output is None: + return + + # 1. 
Assemble rhs + + # hessian input contribution + self.adj_rhs.assign(hessian_input) + + # tlm_output contribution + self.tlm_output.assign(tlm_output) + self.adj_rhs -= firedrake.assemble(self.d2Fdu2_form) + + # tlm_input contribution + for d2Fdmdu, dep in zip(self.d2Fdmdu_forms, + self._coefficient_dependencies()): + if dep.tlm_value is None: # This dependency doesn't depend on the controls + continue + if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess + continue + if len(d2Fdmdu.integrals()) > 0: + self.adj_rhs -= firedrake.assemble(d2Fdmdu) + + # 2. Solve adjoint system + for bc in self.bcs: + bc.homogenize() + + solver = self.cached_solvers[ADJOINT] + adj2_sol = solver._problem.u + adj2_sol.zero() + solver.solve() + + self.adj2_sol.assign(adj2_sol) + + prepared = { + "adj2_sol": adj2_sol.copy(deepcopy=True), + "adj2_sol_bc": None, + } + + return prepared def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): - pass + m = block_variable.output + + if m is self.func: # and not self.linear + return None + + relevant_d2Fdm2_forms = [] + for i, dep in relevant_dependencies: + if i >= len(self.replaced_dependencies): + continue + if dep.tlm_value is None: + continue + if dep.output is self.func: # and not self.linear + continue + relevant_d2Fdm2_forms.append(self.d2Fdm2_adj_forms[idx][i]) + + hessian_output = 0 + + for form in (self.d2Fdudm_forms[idx], + self.dFdm_adj2_forms[idx], + *relevant_d2Fdm2_forms): + if not form.empty(): + hessian_output += firedrake.assemble(-form) + + return hessian_output class GenericSolveBlock(Block): @@ -538,7 +630,10 @@ def _assemble_soa_eq_rhs(self, dFdu_form, adj_sol, hessian_input, d2Fdu2): elif not isinstance(c, firedrake.DirichletBC): dFdu_adj = firedrake.action(firedrake.adjoint(dFdu_form), adj_sol) - b_form += firedrake.derivative(dFdu_adj, c_rep, tlm_input) + # b_form += firedrake.derivative(dFdu_adj, c_rep, 
tlm_input) + bo_form = ufl.algorithms.expand_derivatives( + firedrake.derivative(dFdu_adj, c_rep, tlm_input)) + b_form += bo_form b_form = ufl.algorithms.expand_derivatives(b_form) if len(b_form.integrals()) > 0: @@ -566,6 +661,8 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, hessian_input = hessian_inputs[0] tlm_output = fwd_block_variable.tlm_value + self.adj_state = self.adj_state_buf.copy(deepcopy=True) + if hessian_input is None: return @@ -594,6 +691,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, r["adj_sol2_bdy"] = adj_sol2_bdy r["form"] = F_form r["adj_sol"] = adj_sol + return r def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, @@ -792,6 +890,8 @@ def __init__(self, equation, func, bcs, adj_cache, problem_J, self.problem_J = problem_J self.solver_kwargs = solver_kwargs + self.adj_state_buf = func.copy(deepcopy=True) + super().__init__(lhs, rhs, func, bcs, **{**solver_kwargs, **kwargs}) if self.problem_J is not None: @@ -900,6 +1000,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): ) adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy) self.adj_state = adj_sol + self.adj_state_buf.assign(adj_sol) if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index a8c9bed0fc..2d2e739191 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -3,7 +3,8 @@ from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock, CachedSolverBlock from firedrake.ufl_expr import derivative, adjoint, action -import ufl +from ufl import replace, Action +from ufl.algorithms import expand_derivatives class NonlinearVariationalProblemMixin: @@ -58,7 +59,7 
@@ def wrapper(self, problem, *args, **kwargs): def _ad_cache_forward_solver(self): from firedrake import ( - Function, DirichletBC, + DirichletBC, NonlinearVariationalProblem, NonlinearVariationalSolver) from firedrake.adjoint_utils.blocks.solving import FORWARD @@ -77,7 +78,7 @@ def _ad_cache_forward_solver(self): # We need a handle to the new Function being # solved for so that we can create an NLVS. - Fnew = ufl.replace(F, replace_map) + Fnew = replace(F, replace_map) unew = replace_map[problem.u] # We also need to "replace" all the bcs in @@ -133,6 +134,8 @@ def _ad_cache_tlm_solver(self): dFdm = Cofunction(V.dual()) dudm = Function(V) + self._ad_dFdu = dFdu + # Reuse the same bcs as the forward problem. # TODO: Think about if we should use new bcs. # TODO: solver_parameters @@ -152,7 +155,7 @@ def _ad_cache_tlm_solver(self): dFdm = derivative(-F, m, mtlm) # TODO: Do we need expand_derivatives here? If so, why? - dFdm = ufl.algorithms.expand_derivatives(dFdm) + dFdm = expand_derivatives(dFdm) dFdm_tlm_forms.append(dFdm) # We'll need to update the replaced_tlm @@ -181,7 +184,7 @@ def _ad_cache_adj_solver(self): # Then for the _partial_ derivatives: # (dF/du)*(du/dm) + dF/dm = 0 so we calculate: # (dF/du)*(du/dm) = -dF/dm - dFdu = derivative(F, u, TrialFunction(V)) + dFdu = self._ad_dFdu try: dFdu_adj = adjoint(dFdu) except ValueError: @@ -189,6 +192,8 @@ def _ad_cache_adj_solver(self): # as dFdu might have been simplied to an empty Form dFdu_adj = adjoint(dFdu, derivatives_expanded=True) + self._ad_dFdu_adj = dFdu_adj + # This will be the rhs of the adjoint problem dJdu = Cofunction(V.dual()) adj_sol = Function(V) @@ -215,7 +220,7 @@ def _ad_cache_adj_solver(self): if isinstance(dFdm, Argument): # Corner case. Should be fixed more permanently upstream in UFL. 
# See: https://github.com/FEniCS/ufl/issues/395 - dFdm = ufl.Action(dFdm, adj_sol) + dFdm = Action(dFdm, adj_sol) else: dFdm = dFdm * adj_sol @@ -238,6 +243,86 @@ def _ad_cache_adj_solver(self): # the adj_component for each dependency. self._ad_adj_dFdm_forms = dFdm_adj_forms + def _ad_cache_hessian_solver(self): + from firedrake import ( + Function, TestFunction) + from firedrake.adjoint_utils.blocks.solving import FORWARD + + nlvp = self._ad_solver_cache[FORWARD]._problem + F = nlvp.F + u = nlvp.u + V = u.function_space() + + # 1. Forms to calculate rhs of Hessian solve + + # Calculate d^2F/du^2 * du/dm * dm + # where dm is direction for tlm action so du/dm * dm is tlm output + dFdu = self._ad_dFdu + tlm_output = Function(V) + d2Fdu2 = expand_derivatives( + derivative(dFdu, u, tlm_output)) + + self._ad_tlm_output = tlm_output + + adj_sol = Function(V) + self._ad_adj_sol = adj_sol + + # Contribution from tlm_output + if len(d2Fdu2.integrals()) > 0: + d2Fdu2_form = action(adjoint(d2Fdu2), adj_sol) + else: + d2Fdu2_form = d2Fdu2 + self._ad_d2Fdu2_form = d2Fdu2_form + + # Contributions from each tlm_input + dFdu_adj = action(self._ad_dFdu_adj, adj_sol) + d2Fdmdu_forms = [] + for m, dm in zip(self._ad_replaced_dependencies, + self._ad_replaced_tlms): + d2Fdmdu = expand_derivatives( + derivative(dFdu_adj, m, dm)) + + d2Fdmdu_forms.append(d2Fdmdu) + + self._ad_d2Fdmdu_forms = d2Fdmdu_forms + + # 2. 
Forms to calculate contribution from each control + adj2_sol = Function(V) + self._ad_adj2_sol = adj2_sol + + Fadj = action(F, adj_sol) + Fadj2 = action(F, adj2_sol) + + dFdm_adj2_forms = [] + d2Fdm2_adj_forms = [] + d2Fdudm_forms = [] + for m in self._ad_replaced_dependencies: + dm = TestFunction(m.function_space()) + dFdm_adj2 = expand_derivatives( + derivative(Fadj2, m, dm)) + + dFdm_adj2_forms.append(dFdm_adj2) + + dFdm_adj = derivative(Fadj, m, dm) + + d2Fdudm = expand_derivatives( + derivative(dFdm_adj, u, tlm_output)) + + d2Fdudm_forms.append(d2Fdudm) + + d2Fdm2_adj_forms_k = [] + for m2, dm2 in zip(self._ad_replaced_dependencies, + self._ad_replaced_tlms): + d2Fdm2_adj = expand_derivatives( + derivative(dFdm_adj, m2, dm2)) + d2Fdm2_adj_forms_k.append(d2Fdm2_adj) + + d2Fdm2_adj_forms.append(d2Fdm2_adj_forms_k) + + self._ad_dFdm_adj2_forms = dFdm_adj2_forms + self._ad_d2Fdm2_adj_forms = d2Fdm2_adj_forms + self._ad_d2Fdudm_forms = d2Fdudm_forms + @staticmethod def _ad_annotate_solve(solve): @wraps(solve) @@ -256,6 +341,7 @@ def wrapper(self, **kwargs): self._ad_cache_forward_solver() self._ad_cache_tlm_solver() self._ad_cache_adj_solver() + self._ad_cache_hessian_solver() block = CachedSolverBlock(self._ad_problem.u, self._ad_bcs, @@ -269,6 +355,16 @@ def wrapper(self, **kwargs): self._ad_adj_rhs, self._ad_adj_dFdm_forms, self._ad_adj_residual, + + self._ad_adj_sol, + self._ad_adj2_sol, + self._ad_tlm_output, + self._ad_d2Fdu2_form, + self._ad_d2Fdmdu_forms, + self._ad_dFdm_adj2_forms, + self._ad_d2Fdm2_adj_forms, + self._ad_d2Fdudm_forms, + ad_block_tag=self.ad_block_tag) for dep in self._ad_dependencies_to_add: @@ -357,10 +453,10 @@ def _ad_problem_clone(self, problem, dependencies): from firedrake import NonlinearVariationalProblem _ad_count_map, J_replace_map, F_replace_map = self._build_count_map( problem.J, dependencies, F=problem.F) - nlvp = NonlinearVariationalProblem(ufl.replace(problem.F, F_replace_map), + nlvp = 
NonlinearVariationalProblem(replace(problem.F, F_replace_map), F_replace_map[problem.u_restrict], bcs=problem.bcs, - J=ufl.replace(problem.J, J_replace_map)) + J=replace(problem.J, J_replace_map)) nlvp.is_linear = problem.is_linear nlvp._constant_jacobian = problem._constant_jacobian nlvp._ad_count_map_update(_ad_count_map) @@ -385,7 +481,7 @@ def _ad_adj_lvs_problem(self, block, adj_F): _ad_count_map, J_replace_map, _ = self._build_count_map( adj_F, block._dependencies) lvp = LinearVariationalProblem( - ufl.replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, + replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, bcs=tmp_problem.bcs, constant_jacobian=self._ad_problem._constant_jacobian) lvp._ad_count_map_update(_ad_count_map) diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py index ea88ef6423..99fe61bbfc 100644 --- a/tests/firedrake/adjoint/test_nlvs.py +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -68,13 +68,13 @@ def test_nlvs_adjoint(control_type, bc_type): if control_type == 'bc_control' and bc_type == 'neumann_bc': pytest.skip("Cannot use Neumann BCs as control") - mesh = UnitIntervalMesh(10) + mesh = UnitIntervalMesh(6) x, = SpatialCoordinate(mesh) V = FunctionSpace(mesh, "CG", 1) R = FunctionSpace(mesh, "R", 0) - nt = 4 + nt = 2 dt = Function(R).assign(0.1) ic = Function(V).interpolate(cos(2*pi*x)) @@ -135,13 +135,24 @@ def test_nlvs_adjoint(control_type, bc_type): assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 # tlm + print("tlm test") Jhat(m) assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 # adjoint + print("adjoint test") Jhat(m) assert taylor_test(Jhat, m, h) > 1.95 + # hessian + print("hessian test") + Jhat(m) + taylor = taylor_to_dict(Jhat, m, h) + from pprint import pprint + pprint(taylor) + + assert min(taylor['R2']['Rate']) > 2.95 + if __name__ == "__main__": control_type = "ic_control" From 4d93b0a60574253f15e0cae64e90ca12fbb93d0c Mon Sep 17 00:00:00 2001 
From: Josh Hope-Collins Date: Sat, 11 Oct 2025 16:36:34 +0100 Subject: [PATCH 10/12] cached nlvs block - hessian with bc updates and controls --- firedrake/adjoint_utils/blocks/solving.py | 60 ++++++++++++++--------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 3f34187357..a3e2d99540 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -176,34 +176,44 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar result = solver._problem.u.copy(deepcopy=True) return result - def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - self.update_dependencies(use_output=True) - self.update_adj_dependencies() - + def solve_adj_equation(self, rhs, compute_boundary): for bc in self.bcs: bc.homogenize() solver = self.cached_solvers[ADJOINT] adj_sol = solver._problem.u - dJdu = adj_inputs[0] - self.adj_rhs.assign(dJdu) + self.adj_rhs.assign(rhs) adj_sol.zero() solver.solve() - # store this for Hessian computation - self.adj_sol.assign(adj_sol) - self.adj_sol_buf.assign(adj_sol) - - if self._compute_boundary(relevant_dependencies): + if compute_boundary: adj_sol_bc = firedrake.assemble(self.adj_residual) adj_sol_bc = adj_sol_bc.riesz_representation("l2") else: adj_sol_bc = None + return adj_sol, adj_sol_bc + + def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): + self.update_dependencies(use_output=True) + self.update_adj_dependencies() + + dJdu = adj_inputs[0] + + compute_boundary = self._compute_boundary(relevant_dependencies) + + adj_sol, adj_sol_bc = self.solve_adj_equation(dJdu, compute_boundary) + + # store adj_sol for Hessian computation later. + # self.adj_sol is shared between all blocks that this NLVS + # generates so we can't store it there. Instead store it + # in self.adj_sol_buf which is owned by this block only. 
+ self.adj_sol_buf.assign(adj_sol) + prepared = { - "adj_sol": adj_sol, + "adj_sol": adj_sol.copy(deepcopy=True), "adj_sol_bc": adj_sol_bc } return prepared @@ -240,11 +250,11 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_ # 1. Assemble rhs # hessian input contribution - self.adj_rhs.assign(hessian_input) + hessian_rhs = hessian_input.copy(deepcopy=True) # tlm_output contribution self.tlm_output.assign(tlm_output) - self.adj_rhs -= firedrake.assemble(self.d2Fdu2_form) + hessian_rhs -= firedrake.assemble(self.d2Fdu2_form) # tlm_input contribution for d2Fdmdu, dep in zip(self.d2Fdmdu_forms, @@ -254,22 +264,17 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_ if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess continue if len(d2Fdmdu.integrals()) > 0: - self.adj_rhs -= firedrake.assemble(d2Fdmdu) + hessian_rhs -= firedrake.assemble(d2Fdmdu) # 2. Solve adjoint system - for bc in self.bcs: - bc.homogenize() - - solver = self.cached_solvers[ADJOINT] - adj2_sol = solver._problem.u - adj2_sol.zero() - solver.solve() + compute_boundary = self._compute_boundary(relevant_dependencies) + adj2_sol, adj2_sol_bc = self.solve_adj_equation(hessian_rhs, compute_boundary) self.adj2_sol.assign(adj2_sol) prepared = { "adj2_sol": adj2_sol.copy(deepcopy=True), - "adj2_sol_bc": None, + "adj2_sol_bc": adj2_sol_bc, } return prepared @@ -280,9 +285,16 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_v if m is self.func: # and not self.linear return None + if isinstance(m, firedrake.DirichletBC): + bc = block_variable.output + adj2_sol_bc = prepared["adj2_sol_bc"] + return bc.reconstruct( + g=extract_subfunction(adj2_sol_bc, bc.function_space()) + ) + relevant_d2Fdm2_forms = [] for i, dep in relevant_dependencies: - if i >= len(self.replaced_dependencies): + if i >= len(self._coefficient_dependencies()): continue if dep.tlm_value is None: continue 
From 364dc648cdba23c3fdd5e28cf92ca83dd6a1fb73 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Mon, 13 Oct 2025 08:20:31 +0100 Subject: [PATCH 11/12] cached nlvs block - switches for linear jacobian and adj_state_buf for generic solve block --- firedrake/adjoint_utils/blocks/solving.py | 18 ++++++++++------- firedrake/adjoint_utils/variational_solver.py | 1 + tests/firedrake/adjoint/test_hessian.py | 20 +++++++++---------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index a3e2d99540..c819dd23e7 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -41,7 +41,7 @@ class SolverType(Enum): class CachedSolverBlock(Block): def __init__(self, func, bcs, cached_solvers, - replaced_dependencies, + is_linear, replaced_dependencies, tlm_rhs, replaced_tlms, tlm_dFdm_forms, adj_rhs, adj_dFdm_forms, adj_residual, adj_sol, adj2_sol, tlm_output, @@ -55,6 +55,7 @@ def __init__(self, func, bcs, cached_solvers, self.bcs = bcs self.cached_solvers = cached_solvers self.replaced_dependencies = replaced_dependencies + self.is_linear = is_linear self.tlm_rhs = tlm_rhs self.replaced_tlms = replaced_tlms @@ -112,7 +113,7 @@ def update_tlm_dependencies(self): """ for replaced_dep, dep in zip(self.replaced_tlms, self._coefficient_dependencies()): - if dep.output == self.func: # TODO: and not self.linear + if dep.output == self.func and not self.is_linear: continue if dep.tlm_value is None: # This dependency doesn't depend on the controls continue @@ -165,7 +166,7 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepar for dFdm, dep in zip(self.tlm_dFdm_forms, self.get_dependencies()): if dep.tlm_value is None: # This dependency doesn't depend on the controls continue - if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess + if dep.output is self.func and not 
self.is_linear: # Can't compute dependence on initial guess continue self.tlm_rhs += firedrake.assemble(dFdm) @@ -219,7 +220,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): return prepared def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): - if block_variable.output == self.func: # and not self.linear + if block_variable.output == self.func and not self.is_linear: return None if isinstance(block_variable.output, firedrake.DirichletBC): @@ -261,7 +262,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_ self._coefficient_dependencies()): if dep.tlm_value is None: # This dependency doesn't depend on the controls continue - if dep.output is self.func: # and not self.linear # Can't compute dependence on initial guess + if dep.output is self.func and not self.is_linear: # Can't compute dependence on initial guess continue if len(d2Fdmdu.integrals()) > 0: hessian_rhs -= firedrake.assemble(d2Fdmdu) @@ -282,7 +283,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): m = block_variable.output - if m is self.func: # and not self.linear + if m is self.func and not self.is_linear: return None if isinstance(m, firedrake.DirichletBC): @@ -298,7 +299,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_v continue if dep.tlm_value is None: continue - if dep.output is self.func: # and not self.linear + if dep.output is self.func and not self.is_linear: continue relevant_d2Fdm2_forms.append(self.d2Fdm2_adj_forms[idx][i]) @@ -338,6 +339,8 @@ def __init__(self, lhs, rhs, func, bcs, *args, **kwargs): # Solution function self.func = func self.function_space = self.func.function_space() + # Storage for adjoint solution of this block + self.adj_state_buf = func.copy(deepcopy=True) # Boundary 
conditions self.bcs = [] if bcs is not None: @@ -470,6 +473,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): dFdu_form, dJdu, compute_bdy ) self.adj_state = adj_sol + self.adj_state_buf.assign(adj_sol) if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index 2d2e739191..e7b0e04298 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -346,6 +346,7 @@ def wrapper(self, **kwargs): block = CachedSolverBlock(self._ad_problem.u, self._ad_bcs, self._ad_solver_cache, + self._ad_problem.is_linear, self._ad_replaced_dependencies, self._ad_tlm_rhs, diff --git a/tests/firedrake/adjoint/test_hessian.py b/tests/firedrake/adjoint/test_hessian.py index 41242e7bd7..79142bf046 100644 --- a/tests/firedrake/adjoint/test_hessian.py +++ b/tests/firedrake/adjoint/test_hessian.py @@ -37,7 +37,7 @@ def test_simple_solve(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(2) + f = Function(V).assign(2.) u = TrialFunction(V) v = TestFunction(V) @@ -76,10 +76,10 @@ def test_mixed_derivatives(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(2) + f = Function(V).assign(2.) control_f = Control(f) - g = Function(V).assign(3) + g = Function(V).assign(3.) control_g = Control(g) u = TrialFunction(V) @@ -126,7 +126,7 @@ def test_function(rg): R = FunctionSpace(mesh, "R", 0) c = Function(R, val=4) control_c = Control(c) - f = Function(V).assign(3) + f = Function(V).assign(3.) 
control_f = Control(f) u = Function(V) @@ -139,14 +139,13 @@ def test_function(rg): J = assemble(c ** 2 * u ** 2 * dx) Jhat = ReducedFunctional(J, [control_c, control_f]) - dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True) # Step direction for derivatives and convergence test h_c = Function(R, val=1.0) h_f = rg.uniform(V, 0, 10) # Total derivative - dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True) + dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True) dJdm = assemble(dJdc * h_c * dx + dJdf * h_f * dx) # Hessian @@ -163,7 +162,7 @@ def test_nonlinear(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(5) + f = Function(V).assign(5.) u = Function(V) v = TestFunction(V) @@ -176,6 +175,7 @@ def test_nonlinear(rg): Jhat = ReducedFunctional(J, Control(f)) h = rg.uniform(V, 0, 10) + g = f.copy(deepcopy=True) J.block_variable.adj_value = 1.0 f.block_variable.tlm_value = h @@ -186,8 +186,6 @@ def test_nonlinear(rg): J.block_variable.hessian_value = 0 tape.evaluate_hessian() - g = f.copy(deepcopy=True) - dJdm = J.block_variable.tlm_value Hm = f.block_variable.hessian_value.dat.inner(h.dat) assert taylor_test(Jhat, g, h, dJdm=dJdm, Hm=Hm) > 2.8 @@ -201,11 +199,11 @@ def test_dirichlet(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(30) + f = Function(V).assign(30.) u = Function(V) v = TestFunction(V) - c = Function(V).assign(1) + c = Function(V).assign(1.) 
bc = DirichletBC(V, c, "on_boundary") F = inner(grad(u), grad(v)) * dx + u**4*v*dx - f**2 * v * dx From b89fa4447811a2a8ed21286e4c2cc957f4b5f582 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 14 Oct 2025 12:02:47 +0100 Subject: [PATCH 12/12] WIP: nlvs cached block args-kwargs --- firedrake/adjoint_utils/blocks/solving.py | 6 ++ firedrake/adjoint_utils/variational_solver.py | 44 ++++++++-- tests/firedrake/adjoint/test_assemble.py | 10 ++- tests/firedrake/adjoint/test_nlvs.py | 86 +++++++++++++------ 4 files changed, 108 insertions(+), 38 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index c819dd23e7..40913e774a 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -856,6 +856,12 @@ def solve_init_params(self, args, kwargs, varform): ) self.adj_kwargs.pop("appctx", None) + if hasattr(self, "tlm_args") and len(self.tlm_args) <= 0: + self.tlm_args = self.adj_args + + if hasattr(self, "tlm_kwargs") and len(self.tlm_kwargs) <= 0: + self.tlm_kwargs = self.adj_kwargs.copy() + solver_params = kwargs.get("solver_parameters", None) if solver_params is not None and "mat_type" in solver_params: self.assemble_kwargs["mat_type"] = solver_params["mat_type"] diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index e7b0e04298..6bbe65d8f3 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -1,10 +1,13 @@ import copy from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations -from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock, CachedSolverBlock +from firedrake.adjoint_utils.blocks import ( + NonlinearVariationalSolveBlock, CachedSolverBlock) +from firedrake.adjoint_utils.blocks.solving import solve_init_params from firedrake.ufl_expr import derivative, adjoint, 
action from ufl import replace, Action from ufl.algorithms import expand_derivatives +from types import SimpleNamespace class NonlinearVariationalProblemMixin: @@ -55,6 +58,19 @@ def wrapper(self, problem, *args, **kwargs): self._ad_solver_cache = {} + # process args/kwargs for cached solvers + self._ad_args_kwargs = SimpleNamespace( + forward_args=kwargs.pop("forward_args", []), + forward_kwargs=kwargs.pop("forward_kwargs", {}), + adj_args=kwargs.pop("adj_args", []), + adj_kwargs=kwargs.pop("adj_kwargs", {}), + tlm_args=kwargs.pop("tlm_args", []), + tlm_kwargs=kwargs.pop("tlm_kwargs", {}), + assemble_kwargs={} + ) + solve_init_params(self._ad_args_kwargs, + args, kwargs, varform=True) + return wrapper def _ad_cache_forward_solver(self): @@ -99,7 +115,10 @@ def _ad_cache_forward_solver(self): # This NLVS will be used to recompute the solve. # TODO: solver_parameters nlvp = NonlinearVariationalProblem(Fnew, unew, bcs=bcs_new) - nlvs = NonlinearVariationalSolver(nlvp) + nlvs = NonlinearVariationalSolver( + nlvp, + *self._ad_args_kwargs.forward_args, + **self._ad_args_kwargs.forward_kwargs) # The original coefficients will be added as # dependencies to all solve blocks. @@ -140,7 +159,10 @@ def _ad_cache_tlm_solver(self): # TODO: Think about if we should use new bcs. # TODO: solver_parameters lvp = LinearVariationalProblem(dFdu, dFdm, dudm, bcs=self._ad_bcs) - lvs = LinearVariationalSolver(lvp) + lvs = LinearVariationalSolver( + lvp, + *self._ad_args_kwargs.tlm_args, + **self._ad_args_kwargs.tlm_kwargs) self._ad_solver_cache[TLM] = lvs self._ad_tlm_rhs = dFdm @@ -202,7 +224,10 @@ def _ad_cache_adj_solver(self): # TODO: Think about if we should use new bcs. 
# TODO: solver_parameters lvp = LinearVariationalProblem(dFdu_adj, dJdu, adj_sol, bcs=self._ad_bcs) - lvs = LinearVariationalSolver(lvp) + lvs = LinearVariationalSolver( + lvp, + *self._ad_args_kwargs.adj_args, + **self._ad_args_kwargs.adj_kwargs) self._ad_solver_cache[ADJOINT] = lvs self._ad_adj_rhs = dJdu @@ -259,8 +284,13 @@ def _ad_cache_hessian_solver(self): # where dm is direction for tlm action so du/dm * dm is tlm output dFdu = self._ad_dFdu tlm_output = Function(V) - d2Fdu2 = expand_derivatives( - derivative(dFdu, u, tlm_output)) + d2Fdu2 = derivative(dFdu, u, tlm_output) + # print() + # print(f"{dFdu = }") + # print() + # print(f"{d2Fdu2 = }") + # print() + d2Fdu2 = expand_derivatives(d2Fdu2) self._ad_tlm_output = tlm_output @@ -370,6 +400,8 @@ def wrapper(self, **kwargs): for dep in self._ad_dependencies_to_add: block.add_dependency(dep, no_duplicates=True) + # mesh = self._ad_problem.u.function_space().mesh() + # block.add_dependency(mesh, no_duplicates=True) get_working_tape().add_block(block) diff --git a/tests/firedrake/adjoint/test_assemble.py b/tests/firedrake/adjoint/test_assemble.py index b10d68cfcc..9044fe7ab8 100644 --- a/tests/firedrake/adjoint/test_assemble.py +++ b/tests/firedrake/adjoint/test_assemble.py @@ -89,7 +89,7 @@ def test_assemble_1_forms_tlm(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) v = TestFunction(V) - f = Function(V).assign(1) + f = Function(V).assign(1.) 
w1 = assemble(inner(f, v) * dx) w2 = assemble(inner(f**2, v) * dx) @@ -101,9 +101,11 @@ def test_assemble_1_forms_tlm(rg): h = rg.uniform(V) g = f.copy(deepcopy=True) - f.block_variable.tlm_value = h - tape.evaluate_tlm() - assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9) + Jhat(g) + assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9) + # f.block_variable.tlm_value = h + # tape.evaluate_tlm() + # assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9) @pytest.mark.skipcomplex diff --git a/tests/firedrake/adjoint/test_nlvs.py b/tests/firedrake/adjoint/test_nlvs.py index 99fe61bbfc..fdb96e3da1 100644 --- a/tests/firedrake/adjoint/test_nlvs.py +++ b/tests/firedrake/adjoint/test_nlvs.py @@ -52,7 +52,7 @@ def forward(ic, dt, nt, bc_arg=None): solver.solve() nu += dt if bc_arg: - bc_val.assign(bc_val + dt) + bc_val.assign(bc_val + dt/nt) J = assemble(u1*u1*dx) return J @@ -68,14 +68,16 @@ def test_nlvs_adjoint(control_type, bc_type): if control_type == 'bc_control' and bc_type == 'neumann_bc': pytest.skip("Cannot use Neumann BCs as control") - mesh = UnitIntervalMesh(6) + nx = 100000 + nt = 50 + + mesh = UnitIntervalMesh(nx) x, = SpatialCoordinate(mesh) V = FunctionSpace(mesh, "CG", 1) R = FunctionSpace(mesh, "R", 0) - nt = 2 - dt = Function(R).assign(0.1) + dt = Function(R).assign(1/nx) ic = Function(V).interpolate(cos(2*pi*x)) dt0 = dt.copy(deepcopy=True) @@ -99,7 +101,7 @@ def test_nlvs_adjoint(control_type, bc_type): else: raise ValueError(f"Unrecognised {control_type=}") - print("record tape") + PETSc.Sys.Print("record tape") continue_annotation() with set_working_tape() as tape: J = forward(ic0, dt0, nt, bc_arg=bc_arg0) @@ -130,32 +132,60 @@ def test_nlvs_adjoint(control_type, bc_type): dt2 = dt bc_arg2 = m.copy(deepcopy=True) - # recompute component - print("recompute test") - assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 - - # tlm - print("tlm test") - Jhat(m) - assert taylor_test(Jhat, m, h, 
dJdm=Jhat.tlm(h)) > 1.95 - - # adjoint - print("adjoint test") - Jhat(m) - assert taylor_test(Jhat, m, h) > 1.95 - - # hessian - print("hessian test") - Jhat(m) - taylor = taylor_to_dict(Jhat, m, h) - from pprint import pprint - pprint(taylor) - - assert min(taylor['R2']['Rate']) > 2.95 + from mpi4py import MPI + + # # recompute component + # PETSc.Sys.Print("recompute test") + # assert abs(Jhat(m) - forward(ic2, dt2, nt, bc_arg=bc_arg2)) < 1e-14 + + # # tlm + # PETSc.Sys.Print("tlm test") + # Jhat(m) + # assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + + # # adjoint + # PETSc.Sys.Print("adjoint test") + # Jhat(m) + # assert taylor_test(Jhat, m, h) > 1.95 + + # # hessian + # PETSc.Sys.Print("hessian test") + # Jhat(m) + # taylor = taylor_to_dict(Jhat, m, h) + # from pprint import pprint + # pprint(taylor) + + # assert min(taylor['R0']['Rate']) > 0.95 + # assert min(taylor['R1']['Rate']) > 1.95 + # assert min(taylor['R2']['Rate']) > 2.95 + + for _ in range(3): + stime = MPI.Wtime() + Jhat(m) + etime = MPI.Wtime() + PETSc.Sys.Print(f"Recompute time: {etime - stime:.4f}") + + for _ in range(3): + stime = MPI.Wtime() + Jhat.derivative() + etime = MPI.Wtime() + PETSc.Sys.Print(f"Derivative time: {etime - stime:.4f}") + + for _ in range(3): + stime = MPI.Wtime() + Jhat.tlm(h) + etime = MPI.Wtime() + PETSc.Sys.Print(f"TLM time: {etime - stime:.4f}") + + for _ in range(3): + stime = MPI.Wtime() + Jhat.hessian(h, evaluate_tlm=False) + etime = MPI.Wtime() + PETSc.Sys.Print(f"Hessian time: {etime - stime:.4f}") if __name__ == "__main__": control_type = "ic_control" bc_type = "neumann_bc" - print(f"{control_type=} | {bc_type=}") + PETSc.Sys.Print(f"{control_type=} | {bc_type=}") test_nlvs_adjoint(control_type, bc_type)