Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ def _should_compute_boundary_adjoint(self, relevant_dependencies):
def adj_sol(self):
return self.adj_state

@adj_sol.setter
def adj_sol(self, value):
if self.adj_state is None:
self.adj_state = value.copy(deepcopy=True)
else:
self.adj_state.assign(value)

def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
fwd_block_variable = self.get_outputs()[0]
u = fwd_block_variable.output
Expand All @@ -187,7 +194,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq(
dFdu_form, dJdu, compute_bdy
)
self.adj_state = adj_sol
self.adj_sol = adj_sol
if self.adj_cb is not None:
self.adj_cb(adj_sol)
if self.adj_bdy_cb is not None and compute_bdy:
Expand Down Expand Up @@ -408,7 +415,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
firedrake.derivative(dFdu_form, fwd_block_variable.saved_output,
tlm_output))

adj_sol = self.adj_state
adj_sol = self.adj_sol
if adj_sol is None:
raise RuntimeError("Hessian computation was run before adjoint.")
bdy = self._should_compute_boundary_adjoint(relevant_dependencies)
Expand Down Expand Up @@ -726,7 +733,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
relevant_dependencies
)
adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy)
self.adj_state = adj_sol
self.adj_sol = adj_sol
if self.adj_cb is not None:
self.adj_cb(adj_sol)
if self.adj_bdy_cb is not None and compute_bdy:
Expand Down
38 changes: 16 additions & 22 deletions tests/firedrake/adjoint/test_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_simple_solve(rg):
mesh = IntervalMesh(10, 0, 1)
V = FunctionSpace(mesh, "Lagrange", 1)

f = Function(V).assign(2)
f = Function(V).assign(2.)

u = TrialFunction(V)
v = TestFunction(V)
Expand Down Expand Up @@ -76,10 +76,10 @@ def test_mixed_derivatives(rg):
mesh = IntervalMesh(10, 0, 1)
V = FunctionSpace(mesh, "Lagrange", 1)

f = Function(V).assign(2)
f = Function(V).assign(2.)
control_f = Control(f)

g = Function(V).assign(3)
g = Function(V).assign(3.)
control_g = Control(g)

u = TrialFunction(V)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_function(rg):
R = FunctionSpace(mesh, "R", 0)
c = Function(R, val=4)
control_c = Control(c)
f = Function(V).assign(3)
f = Function(V).assign(3.)
control_f = Control(f)

u = Function(V)
Expand All @@ -139,14 +139,14 @@ def test_function(rg):
J = assemble(c ** 2 * u ** 2 * dx)

Jhat = ReducedFunctional(J, [control_c, control_f])
dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True)
dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True)

# Step direction for derivatives and convergence test
h_c = Function(R, val=1.0)
h_f = rg.uniform(V, 0, 10)

# Total derivative
dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True)
dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True)
dJdm = assemble(dJdc * h_c * dx + dJdf * h_f * dx)

# Hessian
Expand All @@ -163,7 +163,7 @@ def test_nonlinear(rg):
mesh = UnitSquareMesh(10, 10)
V = FunctionSpace(mesh, "Lagrange", 1)
R = FunctionSpace(mesh, "R", 0)
f = Function(V).assign(5)
f = Function(V).assign(5.)

u = Function(V)
v = TestFunction(V)
Expand Down Expand Up @@ -201,11 +201,11 @@ def test_dirichlet(rg):
mesh = UnitSquareMesh(10, 10)
V = FunctionSpace(mesh, "Lagrange", 1)

f = Function(V).assign(30)
f = Function(V).assign(30.)

u = Function(V)
v = TestFunction(V)
c = Function(V).assign(1)
c = Function(V).assign(1.)
bc = DirichletBC(V, c, "on_boundary")

F = inner(grad(u), grad(v)) * dx + u**4*v*dx - f**2 * v * dx
Expand Down Expand Up @@ -249,24 +249,25 @@ def Dt(u, u_, timestep):
pr = project(sin(2*pi*x), V, annotate=False)
ic = Function(V).assign(pr)

u_ = Function(V)
u = Function(V)
u_ = Function(V).assign(ic)
u = Function(V).assign(ic)
v = TestFunction(V)

nu = Constant(0.0001)

timestep = Constant(1.0/n)
dt = 0.01
nt = 20

params = {
'snes_rtol': 1e-10,
'ksp_type': 'preonly',
'pc_type': 'lu',
}

F = (Dt(u, ic, timestep)*v
F = (Dt(u, u_, dt)*v
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx

bc = DirichletBC(V, 0.0, "on_boundary")
t = 0.0

if solve_type == "nlvs":
use_nlvs = True
Expand All @@ -285,21 +286,14 @@ def Dt(u, u_, timestep):
else:
solve(F == 0, u, bc, solver_parameters=params)
u_.assign(u)
t += float(timestep)

F = (Dt(u, u_, timestep)*v
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx

end = 0.2
while (t <= end):
for _ in range(nt):
if use_nlvs:
solver.solve()
else:
solve(F == 0, u, bc, solver_parameters=params)
u_.assign(u)

t += float(timestep)

J = assemble(u_*u_*dx + ic*ic*dx)

Jhat = ReducedFunctional(J, Control(ic))
Expand Down
27 changes: 14 additions & 13 deletions tests/firedrake/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_interpolate_vector_valued():
J = assemble(inner(f, g)*u**2*dx)
rf = ReducedFunctional(J, Control(f))

h = Function(V1).assign(1)
h = Function(V1).assign(1.)
assert taylor_test(rf, f, h) > 1.9


Expand All @@ -144,7 +144,7 @@ def test_interpolate_tlm():
J = assemble(inner(f, g)*u**2*dx)
rf = ReducedFunctional(J, Control(f))

h = Function(V1).assign(1)
h = Function(V1).assign(1.)
f.block_variable.tlm_value = h

tape = get_working_tape()
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_interpolate_to_function_space_cross_mesh():
mesh_src = UnitSquareMesh(2, 2)
mesh_dest = UnitSquareMesh(3, 3, quadrilateral=True)
V = FunctionSpace(mesh_src, "CG", 1)
W = FunctionSpace(mesh_dest, "DG", 1)
W = FunctionSpace(mesh_dest, "DQ", 1)
R = FunctionSpace(mesh_src, "R", 0)
u = Function(V)

Expand Down Expand Up @@ -290,7 +290,7 @@ def test_interpolate_hessian_linear_expr(rg):
# space h and perterbation direction g.
W = FunctionSpace(mesh, "Lagrange", 2)
R = FunctionSpace(mesh, "R", 0)
f = Function(W).assign(5)
f = Function(W).assign(5.)
# Note that we interpolate from a linear expression
expr_interped = Function(V).interpolate(2*f)

Expand Down Expand Up @@ -345,7 +345,7 @@ def test_interpolate_hessian_nonlinear_expr(rg):
# space h and perterbation direction g.
W = FunctionSpace(mesh, "Lagrange", 2)
R = FunctionSpace(mesh, "R", 0)
f = Function(W).assign(5)
f = Function(W).assign(5.)
# Note that we interpolate from a nonlinear expression
expr_interped = Function(V).interpolate(f**2)

Expand Down Expand Up @@ -400,8 +400,8 @@ def test_interpolate_hessian_nonlinear_expr_multi(rg):
# space h and perterbation direction g.
W = FunctionSpace(mesh, "Lagrange", 2)
R = FunctionSpace(mesh, "R", 0)
f = Function(W).assign(5)
w = Function(W).assign(4)
f = Function(W).assign(5.)
w = Function(W).assign(4.)
c = Function(R, val=2.0)
# Note that we interpolate from a nonlinear expression with 3 coefficients
expr_interped = Function(V).interpolate(f**2+w**2+c**2)
Expand Down Expand Up @@ -460,8 +460,8 @@ def test_interpolate_hessian_nonlinear_expr_multi_cross_mesh(rg):
mesh_src = UnitSquareMesh(11, 11)
R_src = FunctionSpace(mesh_src, "R", 0)
W = FunctionSpace(mesh_src, "Lagrange", 2)
f = Function(W).assign(5)
w = Function(W).assign(4)
f = Function(W).assign(5.)
w = Function(W).assign(4.)
c = Function(R_src, val=2.0)
# Note that we interpolate from a nonlinear expression with 3 coefficients
expr_interped = Function(V).interpolate(f**2+w**2+c**2)
Expand Down Expand Up @@ -1035,8 +1035,9 @@ def u_analytical(x, a, b):
tape = get_working_tape()
# Check the checkpointed boundary conditions are not updating the
# user-defined boundary conditions ``bc_left`` and ``bc_right``.
assert isinstance(tape._blocks[0], DirichletBCBlock) and \
tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg
assert isinstance(tape._blocks[0], DirichletBCBlock)
assert tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg

# tape._blocks[1] is the DirichletBC block for the right boundary
assert isinstance(tape._blocks[1], DirichletBCBlock) and \
tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg
assert isinstance(tape._blocks[1], DirichletBCBlock)
assert tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg