diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 2c1fb4f876..1ace4c5222 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -163,6 +163,13 @@ def _should_compute_boundary_adjoint(self, relevant_dependencies): def adj_sol(self): return self.adj_state + @adj_sol.setter + def adj_sol(self, value): + if self.adj_state is None: + self.adj_state = value.copy(deepcopy=True) + else: + self.adj_state.assign(value) + def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): fwd_block_variable = self.get_outputs()[0] u = fwd_block_variable.output @@ -187,7 +194,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq( dFdu_form, dJdu, compute_bdy ) - self.adj_state = adj_sol + self.adj_sol = adj_sol if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: @@ -408,7 +415,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, firedrake.derivative(dFdu_form, fwd_block_variable.saved_output, tlm_output)) - adj_sol = self.adj_state + adj_sol = self.adj_sol if adj_sol is None: raise RuntimeError("Hessian computation was run before adjoint.") bdy = self._should_compute_boundary_adjoint(relevant_dependencies) @@ -726,7 +733,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): relevant_dependencies ) adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy) - self.adj_state = adj_sol + self.adj_sol = adj_sol if self.adj_cb is not None: self.adj_cb(adj_sol) if self.adj_bdy_cb is not None and compute_bdy: diff --git a/tests/firedrake/adjoint/test_hessian.py b/tests/firedrake/adjoint/test_hessian.py index 294680a554..344868b26e 100644 --- a/tests/firedrake/adjoint/test_hessian.py +++ b/tests/firedrake/adjoint/test_hessian.py @@ -37,7 +37,7 @@ def test_simple_solve(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(2) + f = Function(V).assign(2.) u = TrialFunction(V) v = TestFunction(V) @@ -76,10 +76,10 @@ def test_mixed_derivatives(rg): mesh = IntervalMesh(10, 0, 1) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(2) + f = Function(V).assign(2.) control_f = Control(f) - g = Function(V).assign(3) + g = Function(V).assign(3.) control_g = Control(g) u = TrialFunction(V) @@ -126,7 +126,7 @@ def test_function(rg): R = FunctionSpace(mesh, "R", 0) c = Function(R, val=4) control_c = Control(c) - f = Function(V).assign(3) + f = Function(V).assign(3.) control_f = Control(f) u = Function(V) @@ -139,14 +139,14 @@ def test_function(rg): J = assemble(c ** 2 * u ** 2 * dx) Jhat = ReducedFunctional(J, [control_c, control_f]) - dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True) + dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True) # Step direction for derivatives and convergence test h_c = Function(R, val=1.0) h_f = rg.uniform(V, 0, 10) # Total derivative - dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True) + dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True) dJdm = assemble(dJdc * h_c * dx + dJdf * h_f * dx) # Hessian @@ -163,7 +163,7 @@ def test_nonlinear(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "Lagrange", 1) R = FunctionSpace(mesh, "R", 0) - f = Function(V).assign(5) + f = Function(V).assign(5.) u = Function(V) v = TestFunction(V) @@ -201,11 +201,11 @@ def test_dirichlet(rg): mesh = UnitSquareMesh(10, 10) V = FunctionSpace(mesh, "Lagrange", 1) - f = Function(V).assign(30) + f = Function(V).assign(30.) u = Function(V) v = TestFunction(V) - c = Function(V).assign(1) + c = Function(V).assign(1.) bc = DirichletBC(V, c, "on_boundary") F = inner(grad(u), grad(v)) * dx + u**4*v*dx - f**2 * v * dx @@ -249,13 +249,14 @@ def Dt(u, u_, timestep): pr = project(sin(2*pi*x), V, annotate=False) ic = Function(V).assign(pr) - u_ = Function(V) - u = Function(V) + u_ = Function(V).assign(ic) + u = Function(V).assign(ic) v = TestFunction(V) nu = Constant(0.0001) - timestep = Constant(1.0/n) + dt = 0.01 + nt = 20 params = { 'snes_rtol': 1e-10, @@ -263,10 +264,10 @@ def Dt(u, u_, timestep): 'pc_type': 'lu', } - F = (Dt(u, ic, timestep)*v + F = (Dt(u, u_, dt)*v + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx + bc = DirichletBC(V, 0.0, "on_boundary") - t = 0.0 if solve_type == "nlvs": use_nlvs = True @@ -285,21 +286,14 @@ def Dt(u, u_, timestep): else: solve(F == 0, u, bc, solver_parameters=params) u_.assign(u) - t += float(timestep) - F = (Dt(u, u_, timestep)*v - + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx - - end = 0.2 - while (t <= end): + for _ in range(nt): if use_nlvs: solver.solve() else: solve(F == 0, u, bc, solver_parameters=params) u_.assign(u) - t += float(timestep) - J = assemble(u_*u_*dx + ic*ic*dx) Jhat = ReducedFunctional(J, Control(ic)) diff --git a/tests/firedrake/regression/test_adjoint_operators.py b/tests/firedrake/regression/test_adjoint_operators.py index 03557bf435..cc4f1ade43 100644 --- a/tests/firedrake/regression/test_adjoint_operators.py +++ b/tests/firedrake/regression/test_adjoint_operators.py @@ -124,7 +124,7 @@ def test_interpolate_vector_valued(): J = assemble(inner(f, g)*u**2*dx) rf = ReducedFunctional(J, Control(f)) - h = Function(V1).assign(1) + h = Function(V1).assign(1.) assert taylor_test(rf, f, h) > 1.9 @@ -144,7 +144,7 @@ def test_interpolate_tlm(): J = assemble(inner(f, g)*u**2*dx) rf = ReducedFunctional(J, Control(f)) - h = Function(V1).assign(1) + h = Function(V1).assign(1.) f.block_variable.tlm_value = h tape = get_working_tape() @@ -259,7 +259,7 @@ def test_interpolate_to_function_space_cross_mesh(): mesh_src = UnitSquareMesh(2, 2) mesh_dest = UnitSquareMesh(3, 3, quadrilateral=True) V = FunctionSpace(mesh_src, "CG", 1) - W = FunctionSpace(mesh_dest, "DG", 1) + W = FunctionSpace(mesh_dest, "DQ", 1) R = FunctionSpace(mesh_src, "R", 0) u = Function(V) @@ -290,7 +290,7 @@ def test_interpolate_hessian_linear_expr(rg): # space h and perterbation direction g. W = FunctionSpace(mesh, "Lagrange", 2) R = FunctionSpace(mesh, "R", 0) - f = Function(W).assign(5) + f = Function(W).assign(5.) # Note that we interpolate from a linear expression expr_interped = Function(V).interpolate(2*f) @@ -345,7 +345,7 @@ def test_interpolate_hessian_nonlinear_expr(rg): # space h and perterbation direction g. W = FunctionSpace(mesh, "Lagrange", 2) R = FunctionSpace(mesh, "R", 0) - f = Function(W).assign(5) + f = Function(W).assign(5.) # Note that we interpolate from a nonlinear expression expr_interped = Function(V).interpolate(f**2) @@ -400,8 +400,8 @@ def test_interpolate_hessian_nonlinear_expr_multi(rg): # space h and perterbation direction g. W = FunctionSpace(mesh, "Lagrange", 2) R = FunctionSpace(mesh, "R", 0) - f = Function(W).assign(5) - w = Function(W).assign(4) + f = Function(W).assign(5.) + w = Function(W).assign(4.) c = Function(R, val=2.0) # Note that we interpolate from a nonlinear expression with 3 coefficients expr_interped = Function(V).interpolate(f**2+w**2+c**2) @@ -460,8 +460,8 @@ def test_interpolate_hessian_nonlinear_expr_multi_cross_mesh(rg): mesh_src = UnitSquareMesh(11, 11) R_src = FunctionSpace(mesh_src, "R", 0) W = FunctionSpace(mesh_src, "Lagrange", 2) - f = Function(W).assign(5) - w = Function(W).assign(4) + f = Function(W).assign(5.) + w = Function(W).assign(4.) c = Function(R_src, val=2.0) # Note that we interpolate from a nonlinear expression with 3 coefficients expr_interped = Function(V).interpolate(f**2+w**2+c**2) @@ -1035,8 +1035,9 @@ def u_analytical(x, a, b): tape = get_working_tape() # Check the checkpointed boundary conditions are not updating the # user-defined boundary conditions ``bc_left`` and ``bc_right``. - assert isinstance(tape._blocks[0], DirichletBCBlock) and \ - tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg + assert isinstance(tape._blocks[0], DirichletBCBlock) + assert tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg + # tape._blocks[1] is the DirichletBC block for the right boundary - assert isinstance(tape._blocks[1], DirichletBCBlock) and \ - tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg + assert isinstance(tape._blocks[1], DirichletBCBlock) + assert tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg