Skip to content

Commit b5f8c0e

Browse files
committed
suggestions and tidy
removed GlobalWrapper class typo
1 parent c8f90bf commit b5f8c0e

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

firedrake/interpolation.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def __init__(self, expr: Interpolate):
209209
"""The dual argument slot of the Interpolate expression."""
210210
self.target_space = dual_arg.function_space().dual()
211211
"""The primal space we are interpolating into."""
212-
self.target_mesh = self.target_space.mesh().unique()
212+
# Delay calling .unique() because MixedInterpolator is fine with MeshSequence
213+
self.target_mesh = self.target_space.mesh()
213214
"""The domain we are interpolating into."""
214215
self.source_mesh = extract_unique_domain(operand) or self.target_mesh
215216
"""The domain we are interpolating from."""
@@ -254,16 +255,16 @@ def assemble(
254255
"""Assemble the interpolation. The result depends on the rank (number of arguments)
255256
of the :class:`Interpolate` expression:
256257
257-
* rank-2: assemble the operator and return a matrix
258-
* rank-1: assemble the action and return a function or cofunction
259-
* rank-0: assemble the action and return a scalar by applying the dual argument
258+
* rank 2: assemble the operator and return a matrix
259+
* rank 1: assemble the action and return a function or cofunction
260+
* rank 0: assemble the action and return a scalar by applying the dual argument
260261
261262
Parameters
262263
----------
263264
tensor
264-
Optional tensor to store the interpolated result. For rank-2
265+
Optional tensor to store the interpolated result. For rank 2
265266
expressions this is expected to be a subclass of
266-
:class:`~firedrake.matrix.MatrixBase`. For lower-rank expressions
267+
:class:`~firedrake.matrix.MatrixBase`. For rank 1 expressions
267268
this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`,
268269
for forward and adjoint interpolation respectively.
269270
bcs
@@ -316,7 +317,9 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
316317
try:
317318
source_mesh = extract_unique_domain(operand) or target_mesh
318319
except ValueError:
319-
raise NotImplementedError("Interpolating an expression defined on multiple meshes is not implemented yet.")
320+
raise NotImplementedError(
321+
"Interpolating an expression with no arguments defined on multiple meshes is not implemented yet."
322+
)
320323

321324
try:
322325
target_mesh = target_mesh.unique()
@@ -384,6 +387,7 @@ class CrossMeshInterpolator(Interpolator):
384387
@no_annotations
385388
def __init__(self, expr: Interpolate):
386389
super().__init__(expr)
390+
self.target_mesh = self.target_mesh.unique()
387391
if self.access and self.access != op2.WRITE:
388392
raise NotImplementedError(
389393
"Access other than op2.WRITE not implemented for cross-mesh interpolation."
@@ -555,6 +559,7 @@ class SameMeshInterpolator(Interpolator):
555559
@no_annotations
556560
def __init__(self, expr):
557561
super().__init__(expr)
562+
self.target_mesh = self.target_mesh.unique()
558563
subset = self.subset
559564
if subset is None:
560565
target = self.target_mesh.topology
@@ -623,8 +628,8 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
623628

624629
def _get_callable(self, tensor=None, bcs=None):
625630
if (isinstance(tensor, Cofunction) and isinstance(self.dual_arg, Cofunction)) and set(tensor.dat).intersection(set(self.dual_arg.dat)):
626-
# adjoint one-form case: we need a zero tensor, so if it shares dats with
627-
# the dual_arg we cannot use it directly
631+
# adjoint one-form case: we need an empty tensor, so if it shares dats with
632+
# the dual_arg we cannot use it directly, so we store it
628633
f = self._get_tensor()
629634
copyout = (partial(f.dat.copy, tensor.dat),)
630635
else:
@@ -765,10 +770,14 @@ def _build_interpolation_callables(
765770
# Reconstruct the expression as an Interpolate
766771
V = expr.arguments()[-1].function_space().dual()
767772
expr = interpolate(zero(V.value_shape), V)
773+
768774
if not isinstance(expr, Interpolate):
769775
raise ValueError("Expecting to interpolate a symbolic Interpolate expression.")
776+
770777
dual_arg, operand = expr.argument_slots()
778+
assert isinstance(dual_arg, Cofunction | Coargument)
771779
V = dual_arg.function_space().dual()
780+
772781
try:
773782
to_element = create_element(V.ufl_element())
774783
except KeyError:
@@ -808,8 +817,7 @@ def _build_interpolation_callables(
808817
# For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
809818
# contributions from the facet DOFs of the dual argument.
810819
# The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity.
811-
needs_weight = isinstance(dual_arg, Cofunction) and not to_element.is_dg()
812-
if needs_weight:
820+
if isinstance(dual_arg, Cofunction) and not to_element.is_dg():
813821
# Create a buffer for the weighted Cofunction
814822
W = dual_arg.function_space()
815823
v = Function(W)
@@ -1184,14 +1192,6 @@ def vom_cell_parent_node_map_extruded(vertex_only_mesh: MeshGeometry, extruded_c
11841192
)
11851193

11861194

1187-
class GlobalWrapper(object):
1188-
"""Wrapper object that fakes a Global to behave like a Function."""
1189-
def __init__(self, glob):
1190-
self.dat = glob
1191-
self.cell_node_map = lambda *arguments: None
1192-
self.ufl_domain = lambda: None
1193-
1194-
11951195
class VomOntoVomMat:
11961196
"""
11971197
Object that facilitates interpolation between a VertexOnlyMesh and its
@@ -1523,13 +1523,15 @@ def __init__(self, expr: Interpolate):
15231523
"""
15241524
super().__init__(expr)
15251525

1526-
def _get_sub_interpolators(self, bcs: Iterable[DirichletBC] | None = None) -> dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]:
1527-
"""Gets `Interpolator`s anf boundary conditions for each sub-Interpolate
1526+
def _get_sub_interpolators(
1527+
self, bcs: Iterable[DirichletBC] | None = None
1528+
) -> dict[tuple[int] | tuple[int, int], tuple[Interpolator, list[DirichletBC]]]:
1529+
"""Gets `Interpolator`s and boundary conditions for each sub-Interpolate
15281530
in the mixed expression.
15291531
15301532
Returns
15311533
-------
1532-
dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]
1534+
dict[tuple[int] | tuple[int, int], tuple[Interpolator, list[DirichletBC]]]
15331535
A map from block index tuples to `Interpolator`s and bcs.
15341536
"""
15351537
# Get the primal spaces
@@ -1549,7 +1551,7 @@ def _get_sub_interpolators(self, bcs: Iterable[DirichletBC] | None = None) -> di
15491551
self.ufl_interpolate = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space)
15501552

15511553
# Get sub-interpolators and sub-bcs for each block
1552-
Isub: dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]] = {}
1554+
Isub: dict[tuple[int] | tuple[int, int], tuple[Interpolator, list[DirichletBC]]] = {}
15531555
for indices, form in split_form(self.ufl_interpolate):
15541556
if isinstance(form, ZeroBaseForm):
15551557
# Ensure block sparsity

firedrake/preconditioners/hiptmair.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def coarsen(self, pc):
202202

203203
coarse_space_bcs = tuple(coarse_space_bcs)
204204
if G_callback is None:
205-
interp_petscmat = chop(assemble(interpolate(dminus(trial), V), bcs=bcs + coarse_space_bcs).mat())
205+
interp_petscmat = chop(assemble(interpolate(dminus(trial), V), bcs=bcs + coarse_space_bcs).petscmat)
206206
else:
207207
interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs)
208208

0 commit comments

Comments
 (0)