@@ -209,7 +209,8 @@ def __init__(self, expr: Interpolate):
209209 """The dual argument slot of the Interpolate expression."""
210210 self .target_space = dual_arg .function_space ().dual ()
211211 """The primal space we are interpolating into."""
212- self .target_mesh = self .target_space .mesh ().unique ()
212+ # Delay calling .unique() because MixedInterpolator is fine with MeshSequence
213+ self .target_mesh = self .target_space .mesh ()
213214 """The domain we are interpolating into."""
214215 self .source_mesh = extract_unique_domain (operand ) or self .target_mesh
215216 """The domain we are interpolating from."""
@@ -254,16 +255,16 @@ def assemble(
254255 """Assemble the interpolation. The result depends on the rank (number of arguments)
255256 of the :class:`Interpolate` expression:
256257
257- * rank- 2: assemble the operator and return a matrix
258- * rank- 1: assemble the action and return a function or cofunction
259- * rank- 0: assemble the action and return a scalar by applying the dual argument
258+ * rank 2: assemble the operator and return a matrix
259+ * rank 1: assemble the action and return a function or cofunction
260+ * rank 0: assemble the action and return a scalar by applying the dual argument
260261
261262 Parameters
262263 ----------
263264 tensor
264- Optional tensor to store the interpolated result. For rank- 2
265+ Optional tensor to store the interpolated result. For rank 2
265266 expressions this is expected to be a subclass of
266- :class:`~firedrake.matrix.MatrixBase`. For lower- rank expressions
267+ :class:`~firedrake.matrix.MatrixBase`. For rank 1 expressions
267268 this is a :class:`~firedrake.function.Function` or :class:`~firedrake.cofunction.Cofunction`,
268269 for forward and adjoint interpolation respectively.
269270 bcs
@@ -316,7 +317,9 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
316317 try :
317318 source_mesh = extract_unique_domain (operand ) or target_mesh
318319 except ValueError :
319- raise NotImplementedError ("Interpolating an expression defined on multiple meshes is not implemented yet." )
320+ raise NotImplementedError (
321+ "Interpolating an expression with no arguments defined on multiple meshes is not implemented yet."
322+ )
320323
321324 try :
322325 target_mesh = target_mesh .unique ()
@@ -384,6 +387,7 @@ class CrossMeshInterpolator(Interpolator):
384387 @no_annotations
385388 def __init__ (self , expr : Interpolate ):
386389 super ().__init__ (expr )
390+ self .target_mesh = self .target_mesh .unique ()
387391 if self .access and self .access != op2 .WRITE :
388392 raise NotImplementedError (
389393 "Access other than op2.WRITE not implemented for cross-mesh interpolation."
@@ -555,6 +559,7 @@ class SameMeshInterpolator(Interpolator):
555559 @no_annotations
556560 def __init__ (self , expr ):
557561 super ().__init__ (expr )
562+ self .target_mesh = self .target_mesh .unique ()
558563 subset = self .subset
559564 if subset is None :
560565 target = self .target_mesh .topology
@@ -623,8 +628,8 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
623628
624629 def _get_callable (self , tensor = None , bcs = None ):
625630 if (isinstance (tensor , Cofunction ) and isinstance (self .dual_arg , Cofunction )) and set (tensor .dat ).intersection (set (self .dual_arg .dat )):
626- # adjoint one-form case: we need a zero tensor, so if it shares dats with
627- # the dual_arg we cannot use it directly
631+ # adjoint one-form case: we need an empty tensor, so if it shares dats with
632+ # the dual_arg we cannot use it directly, so we store it
628633 f = self ._get_tensor ()
629634 copyout = (partial (f .dat .copy , tensor .dat ),)
630635 else :
@@ -765,10 +770,14 @@ def _build_interpolation_callables(
765770 # Reconstruct the expression as an Interpolate
766771 V = expr .arguments ()[- 1 ].function_space ().dual ()
767772 expr = interpolate (zero (V .value_shape ), V )
773+
768774 if not isinstance (expr , Interpolate ):
769775 raise ValueError ("Expecting to interpolate a symbolic Interpolate expression." )
776+
770777 dual_arg , operand = expr .argument_slots ()
778+ assert isinstance (dual_arg , Cofunction | Coargument )
771779 V = dual_arg .function_space ().dual ()
780+
772781 try :
773782 to_element = create_element (V .ufl_element ())
774783 except KeyError :
@@ -808,8 +817,7 @@ def _build_interpolation_callables(
808817 # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
809818 # contributions from the facet DOFs of the dual argument.
810819 # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity.
811- needs_weight = isinstance (dual_arg , Cofunction ) and not to_element .is_dg ()
812- if needs_weight :
820+ if isinstance (dual_arg , Cofunction ) and not to_element .is_dg ():
813821 # Create a buffer for the weighted Cofunction
814822 W = dual_arg .function_space ()
815823 v = Function (W )
@@ -1184,14 +1192,6 @@ def vom_cell_parent_node_map_extruded(vertex_only_mesh: MeshGeometry, extruded_c
11841192 )
11851193
11861194
1187- class GlobalWrapper (object ):
1188- """Wrapper object that fakes a Global to behave like a Function."""
1189- def __init__ (self , glob ):
1190- self .dat = glob
1191- self .cell_node_map = lambda * arguments : None
1192- self .ufl_domain = lambda : None
1193-
1194-
11951195class VomOntoVomMat :
11961196 """
11971197 Object that facilitates interpolation between a VertexOnlyMesh and its
@@ -1523,13 +1523,15 @@ def __init__(self, expr: Interpolate):
15231523 """
15241524 super ().__init__ (expr )
15251525
1526- def _get_sub_interpolators (self , bcs : Iterable [DirichletBC ] | None = None ) -> dict [tuple [int , int ], tuple [Interpolator , list [DirichletBC ]]]:
1527- """Gets `Interpolator`s anf boundary conditions for each sub-Interpolate
1526+ def _get_sub_interpolators (
1527+ self , bcs : Iterable [DirichletBC ] | None = None
1528+ ) -> dict [tuple [int ] | tuple [int , int ], tuple [Interpolator , list [DirichletBC ]]]:
1529+ """Gets `Interpolator`s and boundary conditions for each sub-Interpolate
15281530 in the mixed expression.
15291531
15301532 Returns
15311533 -------
1532- dict[tuple[int, int], tuple[Interpolator, list[DirichletBC]]]
1534+ dict[tuple[int] | tuple[int , int], tuple[Interpolator, list[DirichletBC]]]
15331535 A map from block index tuples to `Interpolator`s and bcs.
15341536 """
15351537 # Get the primal spaces
@@ -1549,7 +1551,7 @@ def _get_sub_interpolators(self, bcs: Iterable[DirichletBC] | None = None) -> di
15491551 self .ufl_interpolate = self .ufl_interpolate ._ufl_expr_reconstruct_ (self .operand , self .target_space )
15501552
15511553 # Get sub-interpolators and sub-bcs for each block
1552- Isub : dict [tuple [int , int ], tuple [Interpolator , list [DirichletBC ]]] = {}
1554+ Isub : dict [tuple [int ] | tuple [ int , int ], tuple [Interpolator , list [DirichletBC ]]] = {}
15531555 for indices , form in split_form (self .ufl_interpolate ):
15541556 if isinstance (form , ZeroBaseForm ):
15551557 # Ensure block sparsity
0 commit comments